[Mlir-commits] [mlir] [MLIR][TOSA] add additional verification to TOSA (PR #108133)

Arteen Abrishami llvmlistbot at llvm.org
Tue Sep 10 19:02:52 PDT 2024


https://github.com/arteen1000 created https://github.com/llvm/llvm-project/pull/108133

----------
Motivation:
----------

Spec conformance. Allows downstream TOSA code to rely on verified invariants.

------------
Changes Made:
------------

Add full permutation tensor verification to tosa.TRANSPOSE. Previously, the verifier did not check that permutation values were between 0 and (rank - 1).
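
For example, out-of-range permutation values are now rejected, as exercised
by the new invalid.mlir tests:

  func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
    %perms = "tosa.const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
    // expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
    %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
    return %1 : tensor<*xi32>
  }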

Update tosa.TRANSPOSE perms data type to be strictly i32.
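
Concretely, a perms operand of type tensor<Nxi64> no longer verifies, and the
existing tests were migrated accordingly, e.g. in constant-op-fold.mlir:

  %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
  %1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<3x1x4x2xi32>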

Verify input/output shapes for tosa.TRANSPOSE.
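
When both input and output are ranked, each static output dimension i must
equal input dimension perms[i]; dynamic dimensions are skipped. From the new
invalid.mlir tests:

  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
  // expected-error@+1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
  %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>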

Add a verifier to tosa.CONST, with consideration for quantization.
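
The value attribute and the result must both be tensors with matching element
types, except that a quantized result type may be backed by an attribute of
its storage type. From the new invalid.mlir tests:

  // expected-error@+1 {{'tosa.const' op expected same attr/result element types}}
  %0 = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xf32>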

Fix TOSA conformance of the tensor type to disallow dimensions of size 0 for ranked tensors, per the spec.
This is not the same as rank-0 tensors. An example of a disallowed tensor is tensor<3x0xi32>. Naturally, this means that the number of elements in a TOSA tensor is always greater than 0.
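
Since the illegality is now caught by the type constraint itself, the
diagnostics change accordingly, e.g. in invalid.mlir:

  // expected-error@+1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}}
  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>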

From 925f1f3f77fe437ddb932f0406be672b933b43cf Mon Sep 17 00:00:00 2001
From: Arteen Abrishami <arteen.abrishami at arm.com>
Date: Thu, 5 Sep 2024 20:15:51 +0000
Subject: [PATCH] [MLIR][TOSA] add additional verification to TOSA

----------
Motivation:
----------

Spec conformance. Allows downstream TOSA code to rely on verified
invariants.

------------
Changes Made:
------------

Add full permutation tensor verification to tosa.TRANSPOSE.
Previously, the verifier did not check that permutation values were
between 0 and (rank - 1).

Update tosa.TRANSPOSE perms data type to be strictly i32.

Verify input/output shapes for tosa.TRANSPOSE.

Add a verifier to tosa.CONST, with consideration for quantization.

Fix TOSA conformance of the tensor type to disallow dimensions of size 0
for ranked tensors, per the spec.
This is not the same as rank-0 tensors. An example of a disallowed
tensor is tensor<3x0xi32>. Naturally, this means that the number of
elements in a TOSA tensor is always greater than 0.

Signed-off-by: Arteen Abrishami <arteen.abrishami at arm.com>
---
 .../mlir/Dialect/Tosa/IR/CMakeLists.txt       |   4 +-
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  58 ++++----
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |  61 +++++---
 .../mlir/Dialect/Tosa/Utils/ConversionUtils.h |  13 ++
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  21 +--
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |  64 ++++++---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 135 ++++++++++--------
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |  10 +-
 .../TosaToLinalg/tosa-to-linalg-pipeline.mlir |   2 +-
 mlir/test/Dialect/Tosa/constant-op-fold.mlir  |   4 +-
 mlir/test/Dialect/Tosa/invalid.mlir           |  74 ++++++++--
 mlir/test/Dialect/Tosa/ops.mlir               |  16 +++
 12 files changed, 301 insertions(+), 161 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index 12b4fc402c390f..1ee105f0ceb98b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -3,8 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
 add_mlir_interface(TosaInterfaces)
 
 set(LLVM_TARGET_DEFINITIONS TosaOps.td)
-mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs)
+mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
+mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
 add_public_tablegen_target(MLIRTosaAttributesIncGen)
 
 set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..63572f287b7dde 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -73,7 +73,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
@@ -102,9 +101,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    4DTensorOf<[Tosa_Weight]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -132,9 +130,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
 
   let arguments = (ins
     Tosa_Tensor5D:$input,
-    TensorRankOf<[Tosa_Weight], [5]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
@@ -163,9 +160,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    4DTensorOf<[Tosa_Weight]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -232,7 +228,7 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
 
   let arguments = (ins
     Tosa_Tensor2D:$input,
-    2DTensorOf<[Tosa_Weight]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
     Tosa_Tensor1D:$bias,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
   );
@@ -347,9 +343,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    4DTensorOf<[Tosa_Weight]>:$filter,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttrUpto4:$out_shape,
@@ -641,12 +636,12 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    I1Tensor:$input2
+    Tosa_I1Tensor:$input1,
+    Tosa_I1Tensor:$input2
   );
 
   let results = (outs
-    I1Tensor:$z
+    Tosa_I1Tensor:$z
   );
 }
 
@@ -708,12 +703,12 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    I1Tensor:$input2
+    Tosa_I1Tensor:$input1,
+    Tosa_I1Tensor:$input2
   );
 
   let results = (outs
-    I1Tensor:$z
+    Tosa_I1Tensor:$z
   );
 }
 
@@ -731,12 +726,12 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    I1Tensor:$input2
+    Tosa_I1Tensor:$input1,
+    Tosa_I1Tensor:$input2
   );
 
   let results = (outs
-    I1Tensor:$z
+    Tosa_I1Tensor:$z
   );
 }
 
@@ -1085,11 +1080,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
   }];
 
   let arguments = (ins
-    I1Tensor:$input1
+    Tosa_I1Tensor:$input1
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 }
 
@@ -1208,7 +1203,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   }];
 
   let arguments = (ins
-    I1Tensor:$pred,
+    Tosa_I1Tensor:$pred,
     Tosa_Tensor:$on_true,
     Tosa_Tensor:$on_false
   );
@@ -1249,7 +1244,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 
   let extraClassDeclaration = [{
@@ -1277,7 +1272,7 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1300,7 +1295,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1721,7 +1716,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Int32Or64Tensor:$perms
+    Tosa_Int32Tensor:$perms
   );
 
   let results = (
@@ -1729,7 +1724,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
   );
 
   let extraClassDeclaration = [{
-    LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
+    LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
   }];
 
   let hasCanonicalizer = 1;
@@ -1755,7 +1750,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values,
-    2DTensorOf<[Tosa_Int32]>:$indices
+    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
   );
 
   let results = (outs
@@ -1776,7 +1771,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values_in,
-    2DTensorOf<[Tosa_Int32]>:$indices,
+    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
     Tosa_Tensor3D:$input
   );
 
@@ -1947,10 +1942,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+    TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2054,7 +2050,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   }];
 
   let arguments = (ins
-    I1Tensor:$cond,
+    Tosa_I1Tensor:$cond,
     Variadic<Tosa_Tensor>:$inputs
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 14fc9c7a6730cc..c3a0128e95a84b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
 def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
                              Tosa_QuantizedInt, AnyFloat]>;
 
+//===----------------------------------------------------------------------===//
+// TOSA Tensor Conformance
+//===----------------------------------------------------------------------===//
+
+def HasNo0Dimensions : And<[
+    IsRankedTensorTypePred,
+    CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
+
+class TosaTensorOf<
+    list<Type> allowedTypes, string summary = "tosa-conformant tensor">
+    : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
+
+class TosaRankedTensorOf<
+    list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant ranked tensor">
+    : RankedTensorOf<allowedTypes, !listconcat([HasNo0Dimensions], preds), summary>;
+
+class TosaUnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant unranked tensor">
+    : UnrankedTensorOf<allowedTypes, preds, summary>;
+
+class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
+    : TosaRankedTensorOf<allowedTypes,
+      [HasAnyRankOfPred<ranks>],
+      !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
+
 //===----------------------------------------------------------------------===//
 // Tensor types
 //===----------------------------------------------------------------------===//
 
-def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
-def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
+def Tosa_I1Tensor : TosaTensorOf<[I1]>;
+def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
+def Tosa_Int32Or64Tensor : TosaTensorOf<[Tosa_Int32Or64]>;
 
-def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
+def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
 
 // Either ranked or unranked tensor of TOSA supported element types.
-def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
 
 // Must be ranked but no further constraints
-def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
 
 // Any tensor element type allowed in Tosa ops.
 def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
                                 AnyFloat.predicate]>, "tosa.dtype">;
 
 class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
-  AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
+  AnyTypeOf<[TosaTensorOf<allowedTypes>, NoneType], description>;
 
 //===----------------------------------------------------------------------===//
 // Tensor types with constrained ranks.
 //===----------------------------------------------------------------------===//
 
 // Rank-0 (scalar) tensor
-def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
+def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
 
 // We include unranked tensors as a supported type for all possible tosa
 // Tensors as unranked does not guarantee invalid. If unranked tensors exist
 // they should be shape propagate used Tosa's shape inference pass and verified
 // to not include any remaining unranked tensors.
-def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
 
-def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
 
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
 def Tosa_Tensor1Dto6D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
 
 def Tosa_TensorUpto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
 
 def Tosa_Int32TensorUpto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
 
 //===----------------------------------------------------------------------===//
 // Generic scalar, vector, or tensor of a particular type.
@@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
 class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
      AnyTypeOf<types>.predicate,
      VectorOf<types>.predicate,
-     TensorOf<types>.predicate]>,
+     TosaTensorOf<types>.predicate]>,
      description>;
 
 def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ef40b348ab5499..90fea1f68beb58 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -216,6 +216,19 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
   return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
 }
 
+// Apply an int32_t permutation to some input, which should be the same size
+// as perms. perms should contain a permutation of 0 .. perms.size() - 1.
+template <typename T>
+SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
+                                    ArrayRef<int32_t> perms) {
+  SmallVector<T> permuted;
+  size_t N = input.size();
+  permuted.resize_for_overwrite(N);
+  for (size_t i = 0; i < N; i++)
+    permuted[i] = input[perms[i]];
+  return permuted;
+}
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 77c3d2e8757910..fe53b499674324 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -313,7 +313,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
         // convolution operation.
         // TODO(suderman): See if this can be efficiently folded - check whether
         // the input is used anywhere else, if not fold the constant.
-        SmallVector<int64_t> weightPerm;
+        SmallVector<int32_t> weightPerm;
         for (int i = 1; i < resultTy.getRank(); i++)
           weightPerm.push_back(i);
         weightPerm.push_back(0);
@@ -321,7 +321,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
         SmallVector<int64_t> newWeightShape;
         for (auto dim : weightPerm)
           newWeightShape.push_back(weightShape[dim]);
-        auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+        auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
         Value weightPermValue =
             rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
         Type newWeightTy =
@@ -337,7 +337,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     if (5 == inputTy.getRank()) {
       // TODO(suderman): See if this can be efficiently folded - check whether
       // the input is used anywhere else, if not fold the constant.
-      SmallVector<int64_t> weightPerm;
+      SmallVector<int32_t> weightPerm;
       for (int i = 1; i < resultTy.getRank(); i++)
         weightPerm.push_back(i);
       weightPerm.push_back(0);
@@ -345,7 +345,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       SmallVector<int64_t> newWeightShape;
       for (auto dim : weightPerm)
         newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+      auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
       Value weightPermValue =
           rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
       Type newWeightTy =
@@ -1040,22 +1040,25 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
 
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                 PatternRewriter &rewriter) const final {
-    SmallVector<int64_t> constantPerms;
+    SmallVector<int32_t> constantPerms;
     if (failed(op.getConstantPerms(constantPerms)))
       return failure();
 
     Location loc = op.getLoc();
-    // The verifier should have made sure we have a valid permutation tensor.
-    assert(isPermutationVector(constantPerms) && "Expected valid permutation");
+    // The verifier should have made sure we have a valid TOSA permutation
+    // tensor; isPermutationVector on its own does not check everything TOSA
+    // expects of the perms, so we do not assert on it here.
     SmallVector<OpFoldResult> inputSizes =
         tensor::getMixedSizes(rewriter, loc, op.getInput1());
     auto permutedSizes =
-        applyPermutation<OpFoldResult>(inputSizes, constantPerms);
+        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
 
     auto permutedInit = rewriter.create<tensor::EmptyOp>(
         loc, permutedSizes, op.getInput1().getType().getElementType());
     rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
-        op, op.getInput1(), permutedInit, constantPerms);
+        op, op.getInput1(), permutedInit,
+        llvm::to_vector(llvm::map_range(
+            constantPerms, [](int32_t v) -> int64_t { return v; })));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index da9a93feac4d65..03876a7c64d07c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,7 +88,7 @@ struct ConsolidateTransposeOptimization
       return rewriter.notifyMatchFailure(transposeOp,
                                          "input must be transpose operation");
 
-    SmallVector<int64_t> transposePerms, innerTransposePerms;
+    SmallVector<int32_t> transposePerms, innerTransposePerms;
     if (transposeOp.getConstantPerms(transposePerms).failed())
       return rewriter.notifyMatchFailure(transposeOp,
                                          "transpose perms must be constant");
@@ -497,8 +497,10 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
     return getInput1();
@@ -536,8 +538,10 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
 
   // IntDivOp inputs must be integer type, no need to check for quantized type
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
   if (lhsAttr && lhsAttr.isSplat()) {
     if (llvm::isa<IntegerType>(resultETy) &&
         lhsAttr.getSplatValue<APInt>().isZero())
@@ -605,10 +609,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
+
   if (rhsTy == resultTy) {
     if (isSplatZero(resultETy, lhsAttr))
       return lhsAttr.resizeSplat(resultTy);
@@ -638,8 +645,10 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
     return getInput1();
@@ -681,8 +690,10 @@ struct APIntFoldGreaterEqual {
 
 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   if (!lhsAttr || !rhsAttr)
     return {};
@@ -693,8 +704,10 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   if (!lhsAttr || !rhsAttr)
     return {};
@@ -706,8 +719,10 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
   Value lhs = getInput1();
   Value rhs = getInput2();
   auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
@@ -838,14 +853,16 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
     return {};
 
   // reshape(const(x)) -> const(reshape-attr(x))
-  if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
+  if (auto operand =
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
     // Constants must have static shape.
     if (!outputTy.hasStaticShape())
       return {};
 
     // Okay to duplicate splat constants.
     if (operand.isSplat())
-      return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
+      return SplatElementsAttr::get(outputTy,
+                                    operand.getSplatValue<Attribute>());
 
     // Don't duplicate other constants.
     if (!getInput1().hasOneUse())
@@ -905,7 +922,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
   auto operand = getInput();
   auto operandTy = llvm::cast<ShapedType>(operand.getType());
   auto axis = getAxis();
-  auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
+  auto operandAttr =
+      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
   if (operandAttr)
     return operandAttr;
 
@@ -954,7 +972,8 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
   if (getOnTrue() == getOnFalse())
     return getOnTrue();
 
-  auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
+  auto predicate =
+      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
   if (!predicate)
     return {};
 
@@ -975,7 +994,8 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::cast<ShapedType>(getType());
 
   // Transposing splat values just means reshaping.
-  if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
+  if (auto input =
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
     if (input.isSplat() && resultTy.hasStaticShape() &&
         input.getType().getElementType() == resultTy.getElementType())
       return input.reshape(resultTy);
@@ -986,11 +1006,11 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
     return {};
 
   // Transpose is not the identity transpose.
-  SmallVector<int64_t> perms;
+  SmallVector<int32_t> perms;
   if (getConstantPerms(perms).failed())
     return {};
 
-  if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
+  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
     return {};
 
   return getInput1();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d93db1b237f316..0d0241fea5152c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -204,22 +204,6 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
 
-static bool hasZeroDimension(ShapedType shapedType) {
-  if (!shapedType.hasRank())
-    return false;
-
-  auto rank = shapedType.getRank();
-
-  for (int i = 0; i < rank; i++) {
-    if (shapedType.isDynamicDim(i))
-      continue;
-    if (shapedType.getDimSize(i) == 0)
-      return true;
-  }
-
-  return false;
-}
-
 template <typename T>
 static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
@@ -236,10 +220,6 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
 
-  if (hasZeroDimension(inputType))
-    return op.emitOpError() << "tensor has a dimension with size zero. Each "
-                               "dimension of a tensor must have size >= 1";
-
   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
 
@@ -262,6 +242,29 @@ static LogicalResult verifyConvOp(T op) {
                    "allowed for float type");
     return failure();
   }
+  return success();
+}
+
+LogicalResult tosa::ConstOp::verify() {
+
+  auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
+  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
+
+  if (!attrType || !outputType) {
+    emitOpError("expected tensors for attr/result type");
+    return failure();
+  }
+
+  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
+          outputType.getElementType())) {
+    if (result.getStorageType() == attrType.getElementType())
+      return success();
+  }
+
+  if (attrType.getElementType() != outputType.getElementType()) {
+    emitOpError("expected same attr/result element types");
+    return failure();
+  }
 
   return success();
 }
@@ -283,9 +286,6 @@ LogicalResult tosa::ArgMaxOp::verify() {
 
 LogicalResult tosa::AvgPool2dOp::verify() {
   auto inputType = llvm::cast<ShapedType>(getInput().getType());
-  if (hasZeroDimension(inputType))
-    return emitOpError() << "tensor has a dimension with size zero. Each "
-                            "dimension of a tensor must have size >= 1";
 
   auto inputETy = inputType.getElementType();
   auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -341,9 +341,9 @@ LogicalResult tosa::ClampOp::verify() {
   if (inputETy != outputETy)
     return emitOpError("input/output element types are incompatible.");
 
-  // if input datatype is float, check that the two min/max_fp attributes share
-  // the same type and that their type is either the same of the input's
-  // datatype, or a float type whose bitwidth > input datatype bitwidth
+  // If input datatype is float, check that the two min/max_fp attributes
+  // share the same type and that their type is either the same of the input's
+  // datatype, or a float type whose bitwidth > input datatype bitwidth.
   if (!inputETy.isInteger(dataTypeBitWidth)) {
     if (((maxFpType != minFpType) ||
          (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
@@ -383,7 +383,8 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   }
 }
 
-/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
+/// Handles tosa.transpose_conv2d which has outpad and output shape
+/// attributes.
 static void buildTransConvOpWithQuantInfo(
     OpBuilder &builder, OperationState &result, Type outputType, Value input,
     Value weight, Value bias, DenseI64ArrayAttr outpad,
@@ -420,9 +421,9 @@ static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   }
 }
 
-/// The tosa.matmul op is also intended to be generated where a fully_connected
-/// op must be constructed where the weight is not a constant. In this case,
-/// the fully_connected op must be expressed using matmul.
+/// The tosa.matmul op is also intended to be generated where a
+/// fully_connected op must be constructed where the weight is not a constant.
+/// In this case, the fully_connected op must be expressed using matmul.
 /// TODO: Add link to the leglization document explaining this.
 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                        OperationState &result, Type outputType,
@@ -457,9 +458,9 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
   }
 }
 
-/// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
-/// but avg_pool operator has its own builder as it has additional parameters
-/// not part of the unary ops.
+/// Both the tosa.avg_pool2d and unary ops use the same
+/// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as
+/// it has additional parameters not part of the unary ops.
 static void
 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                               Type outputType, Value input,
@@ -526,8 +527,8 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
   for (int i = 0, e = operands.size(); i != e; ++i) {
     auto shape = operands.getShape(i);
     if (!shape.hasRank()) {
-      // TODO(jennik): Update function to have better case handling for invalid
-      // operands and for ranked tensors.
+      // TODO(jennik): Update function to have better case handling for
+      // invalid operands and for ranked tensors.
       return failure();
     }
     outRank = std::max<int64_t>(outRank, shape.getRank());
@@ -776,8 +777,8 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     return success();
   }
 
-  // If the input rank is unknown we can info the output rank using the padding
-  // shape's first dim.
+  // If the input rank is unknown we can infer the output rank using the
+  // padding shape's first dim.
   if (!inputShape.hasRank()) {
     if (paddingShape.isDynamicDim(0)) {
       inferredReturnShapes.push_back(ShapedTypeComponents());
@@ -1000,10 +1001,6 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
-  if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
-    return emitOpError() << "tensor has a dimension with size zero. Each "
-                            "dimension of a tensor must have size >= 1";
-
   if ((int64_t)getNewShape().size() != outputType.getRank())
     return emitOpError() << "new shape does not match result rank";
 
@@ -1034,16 +1031,15 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   return mlir::success();
 }
 
-LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
+LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
   // Perms must be constants.
   DenseIntElementsAttr permsAttr;
   if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
     return failure();
 
-  // Transpose is not the identity transpose.
-  perms = llvm::to_vector(
-      llvm::map_range(permsAttr.getValues<APInt>(),
-                      [](const APInt &val) { return val.getSExtValue(); }));
+  perms.clear();
+  for (auto v : permsAttr.getValues<APInt>())
+    perms.push_back(v.getSExtValue());
 
   return success();
 }
@@ -1067,8 +1063,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     return success();
   }
 
-  // This would imply the number of permutations does not match the rank of the
-  // input which is illegal.
+  // This would imply the number of permutations does not match the rank of
+  // the input which is illegal.
   if (permsShape.getDimSize(0) != inputShape.getRank()) {
     return failure();
   }
@@ -1154,19 +1150,38 @@ LogicalResult tosa::TransposeOp::verify() {
                            << " (output rank) but got size "
                            << permType.getDimSize(0);
 
-  SmallVector<int64_t> constantPerms;
+  SmallVector<int32_t> constantPerms;
   if (succeeded(getConstantPerms(constantPerms))) {
-    // Assert that the permutation tensor has a rank, which means that the rank
-    // has been verified above.
+    // Assert that the permutation tensor has a rank, which means that the
+    // rank has been verified above.
     assert(permType.hasRank() &&
            "Unexpectedly found permutation tensor without rank");
-    if (!isPermutationVector(constantPerms))
+    if (!llvm::all_of(constantPerms,
+                      [&constantPerms](int32_t s) {
+                        return s >= 0 &&
+                               static_cast<size_t>(s) < constantPerms.size();
+                      }) ||
+        !isPermutationVector(llvm::to_vector(llvm::map_range(
+            constantPerms, [](int32_t v) -> int64_t { return v; }))))
       return emitOpError() << "expected valid permutation tensor";
 
-    if (inputType.hasRank() && !llvm::all_of(constantPerms, [&](int64_t s) {
-          return s < inputType.getRank();
-        })) {
-      return emitOpError() << "permutation must be within input bounds";
+    // Verify that the types of the input and output tensors are properly
+    // permuted.
+    if (inputType.hasRank() && outputType.hasRank()) {
+      assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
+             inputType.getRank() == outputType.getRank());
+
+      for (auto i = 0; i < outputType.getRank(); i++) {
+        if (inputType.isDynamicDim(constantPerms[i]) ||
+            outputType.isDynamicDim(i))
+          continue;
+
+        if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
+          return emitOpError()
+                 << "expected output tensor dim " << i << " to match "
+                 << "input dim " << constantPerms[i] << " with value of "
+                 << inputType.getDimSize(constantPerms[i]);
+      }
     }
   }
   return success();
@@ -1175,7 +1190,7 @@ LogicalResult tosa::TransposeOp::verify() {
 LogicalResult TransposeOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
 
-  SmallVector<int64_t> transposePerms;
+  SmallVector<int32_t> transposePerms;
   if (getConstantPerms(transposePerms).failed())
     return failure();
 
@@ -1184,7 +1199,7 @@ LogicalResult TransposeOp::reifyResultShapes(
 
   SmallVector<OpFoldResult> returnedDims(inputType.getRank());
   for (auto dim : transposePerms) {
-    int64_t dimInInput = transposePerms[dim];
+    int32_t dimInInput = transposePerms[dim];
     if (inputType.isDynamicDim(dimInInput))
       returnedDims[dim] =
           builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
@@ -1378,8 +1393,8 @@ static LogicalResult verifyReduceOp(T op) {
           << ")";
       return failure();
     }
-    // We can only verify the reduced dimension size to be 1 if this is not the
-    // special case of output rank == 0.
+    // We can only verify the reduced dimension size to be 1 if this is not
+    // the special case of output rank == 0.
     if (outputRank != 0) {
       auto outputShape = outputType.getShape();
       if (!outputType.isDynamicDim(reduceAxis) &&
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 39699ee315e6cb..0d55d1899c713e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
-// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
-// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
+// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
+// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
 
 // CHECK-LABEL: @matmul
 func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -521,7 +521,7 @@ func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tens
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
-  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
   // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
@@ -542,7 +542,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
-  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
+  // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
   // HWCF: %[[TRANSPOSE:.+]] =  linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
 
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index c2bbfd5130ebcd..73da2810abe044 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -24,7 +24,7 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
 
 // check that tosa verify kick in
 func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
-  // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
     %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
       : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 8e19f87dbf4aa8..2902c4a62009e9 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -80,14 +80,14 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
     [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
     [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
   ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
-  %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
+  %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
   //               CHECK: %[[CST:.+]] = "tosa.const"() <{
   // CHECK-SAME{LITERAL}: value = dense<[
   // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
   // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
   // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
   // CHECK-SAME{LITERAL}: ]>
-  %1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
+  %1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<3x1x4x2xi32>
   // CHECK: return %[[CST]]
   return %1 : tensor<3x1x4x2xi32>
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 418f7687b3cce8..414bcfe237d753 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1,6 +1,22 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate=strict-op-spec-alignment
 
 
+func.func @test_const() -> tensor<1xf32> {
+  // expected-error@+1{{'tosa.const' op expected same attr/result element types}}
+  %0 = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// -----
+
+func.func @test_const_non_tensor_attr() {
+  // expected-error@+1{{'tosa.const' op expected tensors for attr/result type}}
+  %0 = "tosa.const"() {value = dense<1.0> : vector<f32>} : () -> tensor<f32>
+  return
+}
+
+// -----
+
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
  // expected-error@+1 {{expect both input and weight to be float or not together, got 'f32' and 'i8'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
@@ -148,6 +164,42 @@ func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>)
 
 // -----
 
+func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+  %perms = "tosa.const"() {value = dense<[-1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
+  %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+  %perms = "tosa.const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
+  %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error@+1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
+  %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
+  return %1 : tensor<3x4xi32>
+}
+
+// -----
+
+func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error@+1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
+  %1 = tosa.transpose %arg0, %perms : (tensor<2x?xi32>, tensor<2xi32>) -> tensor<3x4xi32>
+  return %1 : tensor<3x4xi32>
+}
+
+// -----
+
 func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
   %1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
@@ -269,7 +321,7 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
 // -----
 
 func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
-  // expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  // expected-error@+1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<13x0x3xf32>'}}
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
   return
 }
@@ -277,7 +329,7 @@ func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> ()
 // -----
 
 func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
-  // expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  // expected-error@+1 {{'tosa.reshape' op operand #0 must be tosa-conformant tensor of number values, but got 'tensor<?x0x3xf32>'}}
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
   return
 }
@@ -341,7 +393,7 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
 // -----
 
 func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
-  // expected-error@+1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  // expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}}
   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
   return %0 : tensor<1x27x27x16xf32>
@@ -350,8 +402,8 @@ func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1:
 // -----
 
 func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
-  // expected-error@+1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+  // expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x?x0x4xf32>'}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
   return %0 : tensor<1x27x27x16xf32>
 }
@@ -360,7 +412,7 @@ func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<
 // -----
 
 func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> {
-  // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x7x9xf32>'}}
     %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
       : (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
@@ -369,7 +421,7 @@ func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) ->
 // -----
 
 func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
-  // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
     %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
       : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
@@ -469,7 +521,7 @@ func.func @test_tile_io_rank_mismatch() {
 
 // CHECK-LABEL: @test_invalid_constant_permutation
 func.func @test_invalid_constant_permutation() {
-  // expected-error@+3 {{permutation must be within input bounds}}
+  // expected-error@+3 {{'tosa.transpose' op expected valid permutation tensor}}
   %0 = tensor.empty() : tensor<3x4x5xi32>
   %1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
   %2 = tosa.transpose %0, %1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
@@ -480,7 +532,7 @@ func.func @test_invalid_constant_permutation() {
 
 // CHECK-LABEL: test_rank_size_constant_permutation
 func.func @test_rank_size_constant_permutation() {
-  // expected-error@+4 {{permutation must be within input bounds}}
+  // expected-error@+4 {{'tosa.transpose' op expected valid permutation tensor}}
   %0 = arith.constant 6 : index
   %1 = arith.constant dense<[0, 2]> : tensor<2xi32>
   %2 = tensor.empty(%0) : tensor<?x27xi64>
@@ -492,7 +544,7 @@ func.func @test_rank_size_constant_permutation() {
 
 // CHECK-LABEL: test_large_constant_permutation
 func.func @test_large_constant_permutation() {
-  // expected-error@+4 {{permutation must be within input bounds}}
+  // expected-error@+4 {{'tosa.transpose' op expected valid permutation tensor}}
   %0 = arith.constant 6 : index
   %1 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
   %2 = tensor.empty(%0) : tensor<?x27xi64>
@@ -504,7 +556,7 @@ func.func @test_large_constant_permutation() {
 
 // CHECK-LABEL: test_table_rank0_table
 func.func @test_table_rank0_table(%arg0: tensor<64xi16>, %arg1: tensor<i16>) {
-  // expected-error@+1 {{'tosa.table' op operand #1 must be 1-d tensor, but got 'tensor<i16>'}}
+  // expected-error@+1 {{'tosa.table' op operand #1 must be 1-d tosa-conformant tensor, but got 'tensor<i16>'}}
   %0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<i16>) -> tensor<64xi16>
   return
 }
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 525ee917ccd9fd..a1600fd33c54b4 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -573,6 +573,22 @@ func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
   return %1 : tensor<3x13x21xf32>
 }
 
+// -----
+// CHECK-LABEL: transpose_dynamic_dim
+func.func @test_transpose_dynamic_dim(%arg0: tensor<13x?x3xf32>) -> tensor<3x13x?xf32> {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %1 = tosa.transpose %arg0, %0 : (tensor<13x?x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
+  return %1 : tensor<3x13x?xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose_half_dynamic_dim
+func.func @test_transpose_half_dynamic_dim(%arg0: tensor<13x3x3xf32>) -> tensor<3x13x?xf32> {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %1 = tosa.transpose %arg0, %0 : (tensor<13x3x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
+  return %1 : tensor<3x13x?xf32>
+}
+
 // -----
 // CHECK-LABEL: gather
 func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf32> {


