[Mlir-commits] [mlir] [MLIR][TOSA] add additional verification to TOSA (PR #108133)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 10 19:03:36 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tosa
Author: Arteen Abrishami (arteen1000)
<details>
<summary>Changes</summary>
----------
Motivation:
----------
Conformance with the TOSA specification. Enforcing these constraints in the verifiers lets TOSA code rely on them as invariants.
------------
Changes Made:
------------
Add full verification of the permutation tensor for tosa.TRANSPOSE. Previously, the verifier did not check that the permutation values were in the range [0, rank - 1].
Restrict the element type of the tosa.TRANSPOSE perms operand to i32 (previously i32 or i64).
Verify input/output shapes for tosa.TRANSPOSE.
Add verifier to tosa.CONST, with consideration for quantization.
Fix the TOSA tensor type constraints to disallow dimensions of size 0 in ranked tensors, per the spec.
This is different from rank-0 tensors, which are still allowed. For example, tensor<3x0xi32> is now rejected. Consequently, a ranked TOSA tensor always has at least one element.
---
Patch is 51.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108133.diff
12 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt (+2-2)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+27-31)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+43-18)
- (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+13)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+12-9)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+42-22)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+75-60)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+5-5)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+1-1)
- (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+2-2)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+63-11)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+16)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index 12b4fc402c390f..1ee105f0ceb98b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -3,8 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
add_mlir_interface(TosaInterfaces)
set(LLVM_TARGET_DEFINITIONS TosaOps.td)
-mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs)
+mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
+mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
add_public_tablegen_target(MLIRTosaAttributesIncGen)
set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..63572f287b7dde 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -73,7 +73,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
let arguments = (ins
Tosa_Tensor4D:$input,
-
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
@@ -102,9 +101,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
let arguments = (ins
Tosa_Tensor4D:$input,
- 4DTensorOf<[Tosa_Weight]>:$weight,
+ TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
-
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
@@ -132,9 +130,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
let arguments = (ins
Tosa_Tensor5D:$input,
- TensorRankOf<[Tosa_Weight], [5]>:$weight,
+ TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
-
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
@@ -163,9 +160,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
let arguments = (ins
Tosa_Tensor4D:$input,
- 4DTensorOf<[Tosa_Weight]>:$weight,
+ TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
-
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
@@ -232,7 +228,7 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
let arguments = (ins
Tosa_Tensor2D:$input,
- 2DTensorOf<[Tosa_Weight]>:$weight,
+ TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
Tosa_Tensor1D:$bias,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
);
@@ -347,9 +343,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
let arguments = (ins
Tosa_Tensor4D:$input,
- 4DTensorOf<[Tosa_Weight]>:$filter,
+ TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
Tosa_Tensor1D:$bias,
-
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttrUpto4:$out_shape,
@@ -641,12 +636,12 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
}];
let arguments = (ins
- I1Tensor:$input1,
- I1Tensor:$input2
+ Tosa_I1Tensor:$input1,
+ Tosa_I1Tensor:$input2
);
let results = (outs
- I1Tensor:$z
+ Tosa_I1Tensor:$z
);
}
@@ -708,12 +703,12 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
}];
let arguments = (ins
- I1Tensor:$input1,
- I1Tensor:$input2
+ Tosa_I1Tensor:$input1,
+ Tosa_I1Tensor:$input2
);
let results = (outs
- I1Tensor:$z
+ Tosa_I1Tensor:$z
);
}
@@ -731,12 +726,12 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
}];
let arguments = (ins
- I1Tensor:$input1,
- I1Tensor:$input2
+ Tosa_I1Tensor:$input1,
+ Tosa_I1Tensor:$input2
);
let results = (outs
- I1Tensor:$z
+ Tosa_I1Tensor:$z
);
}
@@ -1085,11 +1080,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
}];
let arguments = (ins
- I1Tensor:$input1
+ Tosa_I1Tensor:$input1
);
let results = (outs
- I1Tensor:$output
+ Tosa_I1Tensor:$output
);
}
@@ -1208,7 +1203,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
}];
let arguments = (ins
- I1Tensor:$pred,
+ Tosa_I1Tensor:$pred,
Tosa_Tensor:$on_true,
Tosa_Tensor:$on_false
);
@@ -1249,7 +1244,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
);
let results = (outs
- I1Tensor:$output
+ Tosa_I1Tensor:$output
);
let extraClassDeclaration = [{
@@ -1277,7 +1272,7 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
);
let results = (outs
- I1Tensor:$output
+ Tosa_I1Tensor:$output
);
let hasFolder = 1;
@@ -1300,7 +1295,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
);
let results = (outs
- I1Tensor:$output
+ Tosa_I1Tensor:$output
);
let hasFolder = 1;
@@ -1721,7 +1716,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
let arguments = (ins
Tosa_Tensor:$input1,
- Tosa_Int32Or64Tensor:$perms
+ Tosa_Int32Tensor:$perms
);
let results = (
@@ -1729,7 +1724,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
);
let extraClassDeclaration = [{
- LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
+ LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
}];
let hasCanonicalizer = 1;
@@ -1755,7 +1750,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
let arguments = (ins
Tosa_Tensor3D:$values,
- 2DTensorOf<[Tosa_Int32]>:$indices
+ TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
);
let results = (outs
@@ -1776,7 +1771,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
let arguments = (ins
Tosa_Tensor3D:$values_in,
- 2DTensorOf<[Tosa_Int32]>:$indices,
+ TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
Tosa_Tensor3D:$input
);
@@ -1947,10 +1942,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);
let results = (outs
- TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
+ TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
);
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -2054,7 +2050,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
}];
let arguments = (ins
- I1Tensor:$cond,
+ Tosa_I1Tensor:$cond,
Variadic<Tosa_Tensor>:$inputs
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 14fc9c7a6730cc..c3a0128e95a84b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
Tosa_QuantizedInt, AnyFloat]>;
+//===----------------------------------------------------------------------===//
+// TOSA Tensor Conformance
+//===----------------------------------------------------------------------===//
+
+def HasNo0Dimensions : And<[
+ IsRankedTensorTypePred,
+ CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
+
+class TosaTensorOf<
+ list<Type> allowedTypes, string summary = "tosa-conformant tensor">
+ : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
+
+class TosaRankedTensorOf<
+ list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant ranked tensor">
+ : RankedTensorOf<allowedTypes, !listconcat([HasNo0Dimensions], preds), summary>;
+
+class TosaUnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant unranked tensor">
+ : UnrankedTensorOf<allowedTypes, preds, summary>;
+
+class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
+ : TosaRankedTensorOf<allowedTypes,
+ [HasAnyRankOfPred<ranks>],
+ !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
+
//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//
-def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
-def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
+def Tosa_I1Tensor : TosaTensorOf<[I1]>;
+def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
+def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;
-def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
+def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
// Either ranked or unranked tensor of TOSA supported element types.
-def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
// Must be ranked but no further constraints
-def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
// Any tensor element type allowed in Tosa ops.
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
AnyFloat.predicate]>, "tosa.dtype">;
class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
- AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
+ AnyTypeOf<[TosaTensorOf<allowedTypes>, NoneType], description>;
//===----------------------------------------------------------------------===//
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//
// Rank-0 (scalar) tensor
-def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
+def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
// We include unranked tensors as a supported type for all possible tosa
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
// they should be shape propagate used Tosa's shape inference pass and verified
// to not include any remaining unranked tensors.
-def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
+def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
-def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
-def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
+def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
- Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
def Tosa_Tensor1Dto6D : AnyTypeOf<[
- Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
def Tosa_TensorUpto4D : AnyTypeOf<[
- Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
- Tosa_UnrankedTensor, TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
@@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
AnyTypeOf<types>.predicate,
VectorOf<types>.predicate,
- TensorOf<types>.predicate]>,
+ TosaTensorOf<types>.predicate]>,
description>;
def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ef40b348ab5499..90fea1f68beb58 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -216,6 +216,19 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
}
+// Apply an int32_t permutation to `input`, which must have the same size as
+// `perms`. `perms` must be a permutation of the indices [0, perms.size() - 1].
+template <typename T>
+SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
+ ArrayRef<int32_t> perms) {
+ SmallVector<T> permuted;
+ size_t N = input.size();
+ permuted.resize_for_overwrite(N);
+ for (size_t i = 0; i < N; i++)
+ permuted[i] = input[perms[i]];
+ return permuted;
+}
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 77c3d2e8757910..fe53b499674324 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -313,7 +313,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
// convolution operation.
// TODO(suderman): See if this can be efficiently folded - check whether
// the input is used anywhere else, if not fold the constant.
- SmallVector<int64_t> weightPerm;
+ SmallVector<int32_t> weightPerm;
for (int i = 1; i < resultTy.getRank(); i++)
weightPerm.push_back(i);
weightPerm.push_back(0);
@@ -321,7 +321,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+ auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
Type newWeightTy =
@@ -337,7 +337,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
if (5 == inputTy.getRank()) {
// TODO(suderman): See if this can be efficiently folded - check whether
// the input is used anywhere else, if not fold the constant.
- SmallVector<int64_t> weightPerm;
+ SmallVector<int32_t> weightPerm;
for (int i = 1; i < resultTy.getRank(); i++)
weightPerm.push_back(i);
weightPerm.push_back(0);
@@ -345,7 +345,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+ auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
Type newWeightTy =
@@ -1040,22 +1040,25 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const final {
- SmallVector<int64_t> constantPerms;
+ SmallVector<int32_t> constantPerms;
if (failed(op.getConstantPerms(constantPerms)))
return failure();
Location loc = op.getLoc();
- // The verifier should have made sure we have a valid permutation tensor.
- assert(isPermutationVector(constantPerms) && "Expected valid permutation");
+ // The verifier should have made sure we have a valid TOSA permutation
+ // tensor. isPermutationVector doesn't actually check the TOSA perms we
+ // expect.
SmallVector<OpFoldResult> inputSizes =
tensor::getMixedSizes(rewriter, loc, op.getInput1());
auto permutedSizes =
- applyPermutation<OpFoldResult>(inputSizes, constantPerms);
+ applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
auto permutedInit = rewriter.create<tensor::EmptyOp>(
loc, permutedSizes, op.getInput1().getType().getElementType());
rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
- op, op.getInput1(), permutedInit, constantPerms);
+ op, op.getInput1(), permutedInit,
+ llvm::to_vector(llvm::map_range(
+ constantPerms, [](int32_t v) -> int64_t { return v; })));
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index da9a93feac4d65..03876a7c64d07c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,7 +88,7 @@ struct ConsolidateTransposeOptimization
return rewriter.notifyMatchFailure(transposeOp,
"input must be transpose operation");
- SmallVector<int64_t> transposePerms, innerTransposePerms;
+ SmallVector<int32_t> transposePerms, innerTransposePerms;
if (transposeOp.getConstantPerms(transposePerms).failed())
return rewriter.notifyMatchFailure(transposeOp,
"transpose perms must be constant");
@@ -497,8 +497,10 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
@@ -536,8 +538,10 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
// IntDivOp inputs must be integer type, no need to check for quantized type
auto resultETy = resultTy.getElementType();
- auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsAttr && lhsAttr.isSplat()) {
if (llvm::isa<IntegerType>(resultETy) &&
lhsAttr.getSplatValue<APInt>().isZero())
@@ -605,10 +609,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
- auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
+
if (rhsTy == resultTy) {
if (isSplatZero(resultETy, lhsAttr))
return lhsAttr.resizeSplat(resultTy);
@@ -638,8 +645,10 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy....
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/108133
More information about the Mlir-commits
mailing list