[Mlir-commits] [mlir] [MLIR][TOSA] add additional verification to TOSA (PR #108133)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 10 19:03:36 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Arteen Abrishami (arteen1000)

<details>
<summary>Changes</summary>

----------
Motivation:
----------

Spec conformance. Allows assumptions to be made in TOSA code.

------------
Changes Made:
------------

Add full permutation tensor verification to tosa.TRANSPOSE. Previously, it would not verify that permutation values were between 0 and (rank - 1).

Update tosa.TRANSPOSE perms data type to be strictly i32.

Verify input/output shapes for tosa.TRANSPOSE.

Add verifier to tosa.CONST, with consideration for quantization.

Fix TOSA conformance of tensor type to disallow dimensions with size 0 for ranked tensors, per spec.
This is not the same as rank 0 tensors. Here is an example of a disallowed tensor: tensor<3x0xi32>. Naturally, this means that the number of elements in a TOSA tensor will always be greater than 0.

---

Patch is 51.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108133.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt (+2-2) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+27-31) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+43-18) 
- (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+13) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+12-9) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+42-22) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+75-60) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+5-5) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+63-11) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+16) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index 12b4fc402c390f..1ee105f0ceb98b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -3,8 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
 add_mlir_interface(TosaInterfaces)
 
 set(LLVM_TARGET_DEFINITIONS TosaOps.td)
-mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs)
+mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
+mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
 add_public_tablegen_target(MLIRTosaAttributesIncGen)
 
 set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..63572f287b7dde 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -73,7 +73,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
@@ -102,9 +101,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    4DTensorOf<[Tosa_Weight]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -132,9 +130,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
 
   let arguments = (ins
     Tosa_Tensor5D:$input,
-    TensorRankOf<[Tosa_Weight], [5]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
@@ -163,9 +160,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    4DTensorOf<[Tosa_Weight]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -232,7 +228,7 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
 
   let arguments = (ins
     Tosa_Tensor2D:$input,
-    2DTensorOf<[Tosa_Weight]>:$weight,
+    TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
     Tosa_Tensor1D:$bias,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
   );
@@ -347,9 +343,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    4DTensorOf<[Tosa_Weight]>:$filter,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
     Tosa_Tensor1D:$bias,
-
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttrUpto4:$out_shape,
@@ -641,12 +636,12 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    I1Tensor:$input2
+    Tosa_I1Tensor:$input1,
+    Tosa_I1Tensor:$input2
   );
 
   let results = (outs
-    I1Tensor:$z
+    Tosa_I1Tensor:$z
   );
 }
 
@@ -708,12 +703,12 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    I1Tensor:$input2
+    Tosa_I1Tensor:$input1,
+    Tosa_I1Tensor:$input2
   );
 
   let results = (outs
-    I1Tensor:$z
+    Tosa_I1Tensor:$z
   );
 }
 
@@ -731,12 +726,12 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    I1Tensor:$input2
+    Tosa_I1Tensor:$input1,
+    Tosa_I1Tensor:$input2
   );
 
   let results = (outs
-    I1Tensor:$z
+    Tosa_I1Tensor:$z
   );
 }
 
@@ -1085,11 +1080,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
   }];
 
   let arguments = (ins
-    I1Tensor:$input1
+    Tosa_I1Tensor:$input1
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 }
 
@@ -1208,7 +1203,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   }];
 
   let arguments = (ins
-    I1Tensor:$pred,
+    Tosa_I1Tensor:$pred,
     Tosa_Tensor:$on_true,
     Tosa_Tensor:$on_false
   );
@@ -1249,7 +1244,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 
   let extraClassDeclaration = [{
@@ -1277,7 +1272,7 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1300,7 +1295,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
   );
 
   let results = (outs
-    I1Tensor:$output
+    Tosa_I1Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1721,7 +1716,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Int32Or64Tensor:$perms
+    Tosa_Int32Tensor:$perms
   );
 
   let results = (
@@ -1729,7 +1724,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
   );
 
   let extraClassDeclaration = [{
-    LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
+    LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
   }];
 
   let hasCanonicalizer = 1;
@@ -1755,7 +1750,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values,
-    2DTensorOf<[Tosa_Int32]>:$indices
+    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
   );
 
   let results = (outs
@@ -1776,7 +1771,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values_in,
-    2DTensorOf<[Tosa_Int32]>:$indices,
+    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
     Tosa_Tensor3D:$input
   );
 
@@ -1947,10 +1942,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+    TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2054,7 +2050,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   }];
 
   let arguments = (ins
-    I1Tensor:$cond,
+    Tosa_I1Tensor:$cond,
     Variadic<Tosa_Tensor>:$inputs
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 14fc9c7a6730cc..c3a0128e95a84b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
 def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
                              Tosa_QuantizedInt, AnyFloat]>;
 
+//===----------------------------------------------------------------------===//
+// TOSA Tensor Conformance
+//===----------------------------------------------------------------------===//
+
+def HasNo0Dimensions : And<[
+    IsRankedTensorTypePred,
+    CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
+
+class TosaTensorOf<
+    list<Type> allowedTypes, string summary = "tosa-conformant tensor">
+    : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
+
+class TosaRankedTensorOf<
+    list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant ranked tensor">
+    : RankedTensorOf<allowedTypes, !listconcat([HasNo0Dimensions], preds), summary>;
+
+class TosaUnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant unranked tensor">
+    : UnrankedTensorOf<allowedTypes, preds, summary>;
+
+class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
+    : TosaRankedTensorOf<allowedTypes,
+      [HasAnyRankOfPred<ranks>],
+      !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
+
 //===----------------------------------------------------------------------===//
 // Tensor types
 //===----------------------------------------------------------------------===//
 
-def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
-def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
+def Tosa_I1Tensor : TosaTensorOf<[I1]>;
+def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
+def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;
 
-def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
+def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
 
 // Either ranked or unranked tensor of TOSA supported element types.
-def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
 
 // Must be ranked but no further constraints
-def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
 
 // Any tensor element type allowed in Tosa ops.
 def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
                                 AnyFloat.predicate]>, "tosa.dtype">;
 
 class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
-  AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
+  AnyTypeOf<[TosaTensorOf<allowedTypes>, NoneType], description>;
 
 //===----------------------------------------------------------------------===//
 // Tensor types with constrained ranks.
 //===----------------------------------------------------------------------===//
 
 // Rank-0 (scalar) tensor
-def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
+def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
 
 // We include unranked tensors as a supported type for all possible tosa
 // Tensors as unranked does not guarantee invalid. If unranked tensors exist
 // they should be shape propagate used Tosa's shape inference pass and verified
 // to not include any remaining unranked tensors.
-def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
 
-def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
 
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
 def Tosa_Tensor1Dto6D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
 
 def Tosa_TensorUpto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
 
 def Tosa_Int32TensorUpto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
 
 //===----------------------------------------------------------------------===//
 // Generic scalar, vector, or tensor of a particular type.
@@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
 class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
      AnyTypeOf<types>.predicate,
      VectorOf<types>.predicate,
-     TensorOf<types>.predicate]>,
+     TosaTensorOf<types>.predicate]>,
      description>;
 
 def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ef40b348ab5499..90fea1f68beb58 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -216,6 +216,19 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
   return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
 }
 
+// Apply an int32_t permutation to some input, that should be of the same
+// size as perms. Perms should contain some permutation of 0 - perms.size() - 1.
+template <typename T>
+SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
+                                    ArrayRef<int32_t> perms) {
+  SmallVector<T> permuted;
+  size_t N = input.size();
+  permuted.resize_for_overwrite(N);
+  for (size_t i = 0; i < N; i++)
+    permuted[i] = input[perms[i]];
+  return permuted;
+}
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 77c3d2e8757910..fe53b499674324 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -313,7 +313,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
         // convolution operation.
         // TODO(suderman): See if this can be efficiently folded - check whether
         // the input is used anywhere else, if not fold the constant.
-        SmallVector<int64_t> weightPerm;
+        SmallVector<int32_t> weightPerm;
         for (int i = 1; i < resultTy.getRank(); i++)
           weightPerm.push_back(i);
         weightPerm.push_back(0);
@@ -321,7 +321,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
         SmallVector<int64_t> newWeightShape;
         for (auto dim : weightPerm)
           newWeightShape.push_back(weightShape[dim]);
-        auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+        auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
         Value weightPermValue =
             rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
         Type newWeightTy =
@@ -337,7 +337,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     if (5 == inputTy.getRank()) {
       // TODO(suderman): See if this can be efficiently folded - check whether
       // the input is used anywhere else, if not fold the constant.
-      SmallVector<int64_t> weightPerm;
+      SmallVector<int32_t> weightPerm;
       for (int i = 1; i < resultTy.getRank(); i++)
         weightPerm.push_back(i);
       weightPerm.push_back(0);
@@ -345,7 +345,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       SmallVector<int64_t> newWeightShape;
       for (auto dim : weightPerm)
         newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+      auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
       Value weightPermValue =
           rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
       Type newWeightTy =
@@ -1040,22 +1040,25 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
 
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                 PatternRewriter &rewriter) const final {
-    SmallVector<int64_t> constantPerms;
+    SmallVector<int32_t> constantPerms;
     if (failed(op.getConstantPerms(constantPerms)))
       return failure();
 
     Location loc = op.getLoc();
-    // The verifier should have made sure we have a valid permutation tensor.
-    assert(isPermutationVector(constantPerms) && "Expected valid permutation");
+    // The verifier should have made sure we have a valid TOSA permutation
+    // tensor. isPermutationVector doesn't actually check the TOSA perms we
+    // expect.
     SmallVector<OpFoldResult> inputSizes =
         tensor::getMixedSizes(rewriter, loc, op.getInput1());
     auto permutedSizes =
-        applyPermutation<OpFoldResult>(inputSizes, constantPerms);
+        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
 
     auto permutedInit = rewriter.create<tensor::EmptyOp>(
         loc, permutedSizes, op.getInput1().getType().getElementType());
     rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
-        op, op.getInput1(), permutedInit, constantPerms);
+        op, op.getInput1(), permutedInit,
+        llvm::to_vector(llvm::map_range(
+            constantPerms, [](int32_t v) -> int64_t { return v; })));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index da9a93feac4d65..03876a7c64d07c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,7 +88,7 @@ struct ConsolidateTransposeOptimization
       return rewriter.notifyMatchFailure(transposeOp,
                                          "input must be transpose operation");
 
-    SmallVector<int64_t> transposePerms, innerTransposePerms;
+    SmallVector<int32_t> transposePerms, innerTransposePerms;
     if (transposeOp.getConstantPerms(transposePerms).failed())
       return rewriter.notifyMatchFailure(transposeOp,
                                          "transpose perms must be constant");
@@ -497,8 +497,10 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
     return getInput1();
@@ -536,8 +538,10 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
 
   // IntDivOp inputs must be integer type, no need to check for quantized type
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
   if (lhsAttr && lhsAttr.isSplat()) {
     if (llvm::isa<IntegerType>(resultETy) &&
         lhsAttr.getSplatValue<APInt>().isZero())
@@ -605,10 +609,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
-  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+  auto lhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
+
   if (rhsTy == resultTy) {
     if (isSplatZero(resultETy, lhsAttr))
       return lhsAttr.resizeSplat(resultTy);
@@ -638,8 +645,10 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy....
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/108133


More information about the Mlir-commits mailing list