[Mlir-commits] [mlir] [mlir][tosa] Change Transpose perms operand to attribute (PR #128115)
Tai Ly
llvmlistbot at llvm.org
Thu Feb 20 18:59:00 PST 2025
https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/128115
This patch changes the perms operand of the TOSA Transpose operator to an i32 array attribute (DenseI32ArrayAttr), so the permutation is always a compile-time constant.
>From 3add2919a09702a0280b4e125b342c21db15aa86 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Wed, 11 Dec 2024 18:46:41 +0000
Subject: [PATCH] [mlir][tosa] Change Transpose perms operand to attribute
This patch changes the perms operand of the TOSA Transpose operator to
an i32 array attribute.
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I6c54d203ee7eb8d77442ff29fdbff2d2e3e6950b
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 6 +-
.../TosaToLinalg/TosaToLinalgNamed.cpp | 16 +-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 31 +--
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 141 ++++------
.../Transforms/TosaDecomposeTransposeConv.cpp | 12 +-
.../Dialect/Tosa/Transforms/TosaFolders.cpp | 7 +-
.../Tosa/Transforms/TosaReduceTransposes.cpp | 15 +-
.../Tosa/Transforms/TosaValidation.cpp | 10 -
.../TosaToLinalg/tosa-to-linalg-named.mlir | 16 +-
mlir/test/Dialect/MemRef/resolve-dim-ops.mlir | 6 +-
mlir/test/Dialect/Tosa/availability.mlir | 3 +-
mlir/test/Dialect/Tosa/canonicalize.mlir | 25 +-
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 46 ++-
mlir/test/Dialect/Tosa/invalid.mlir | 68 ++---
mlir/test/Dialect/Tosa/level_check.mlir | 3 +-
mlir/test/Dialect/Tosa/ops.mlir | 9 +-
.../Tosa/tosa-decompose-transpose-conv.mlir | 18 +-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 23 +-
.../Dialect/Tosa/tosa-reduce-transposes.mlir | 262 ++++++------------
mlir/test/Dialect/Tosa/transpose-fold.mlir | 27 +-
20 files changed, 243 insertions(+), 501 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3de1c21f40b43..a06e03f831985 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2011,7 +2011,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
let arguments = (ins
Tosa_Tensor:$input1,
- Tosa_Int32Tensor:$perms
+ DenseI32ArrayAttr:$perms
);
let results = (
@@ -2023,10 +2023,6 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];
- let extraClassDeclaration = [{
- LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
- }];
-
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index a8fd536dd2548..42e88ee9026ac 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -329,13 +329,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
- Value weightPermValue =
- rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+ auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermValue);
+ weightPermAttr);
}
}
@@ -353,13 +351,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
- Value weightPermValue =
- rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+ auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermValue);
+ weightPermAttr);
}
// Extract the attributes for convolution.
@@ -970,9 +966,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const final {
- SmallVector<int32_t> constantPerms;
- if (failed(op.getConstantPerms(constantPerms)))
- return failure();
+ const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
Location loc = op.getLoc();
// The verifier should have made sure we have a valid TOSA permutation
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..8e2d8662ece8d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,13 +88,10 @@ struct ConsolidateTransposeOptimization
return rewriter.notifyMatchFailure(transposeOp,
"input must be transpose operation");
- SmallVector<int32_t> transposePerms, innerTransposePerms;
- if (transposeOp.getConstantPerms(transposePerms).failed())
- return rewriter.notifyMatchFailure(transposeOp,
- "transpose perms must be constant");
- if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
- return rewriter.notifyMatchFailure(
- transposeOp, "inner transpose perms must be constant");
+ const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
+ const llvm::ArrayRef<int32_t> innerTransposePerms =
+ innerTranspose.getPerms();
+
if (transposePerms.size() != innerTransposePerms.size())
return rewriter.notifyMatchFailure(
transposeOp,
@@ -108,15 +105,9 @@ struct ConsolidateTransposeOptimization
for (int i = 0, s = transposePerms.size(); i < s; ++i)
perms[i] = innerTransposePerms[transposePerms[i]];
- auto permsTy =
- RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
- auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
- Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
- permsTy, permsAttr);
-
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
- innerTranspose.getInput1(), permsValue);
+ innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
return success();
}
@@ -128,10 +119,6 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
- DenseIntElementsAttr permAttr;
- if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
- return rewriter.notifyMatchFailure(op, "Non-constant permutation");
-
if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
return rewriter.notifyMatchFailure(
op, "Src is from transpose, can compose transposes");
@@ -156,9 +143,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
if (numDynDims > 1)
return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
- SmallVector<int64_t> permValues = llvm::to_vector<6>(
- llvm::map_range(permAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+ const llvm::ArrayRef<int32_t> permValues = op.getPerms();
SmallVector<int64_t> nonZeroPerms;
nonZeroPerms.reserve(permValues.size());
@@ -1175,9 +1160,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
}
// Transpose is not the identity transpose.
- SmallVector<int32_t> perms;
- if (getConstantPerms(perms).failed())
- return {};
+ const llvm::ArrayRef<int32_t> perms = getPerms();
if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e9c33e1b1bf10..7030dccd693a4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1374,41 +1374,22 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}
-LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
- // Perms must be constants.
- DenseIntElementsAttr permsAttr;
- if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
- return failure();
-
- perms.clear();
- for (auto v : permsAttr.getValues<APInt>())
- perms.push_back(v.getSExtValue());
-
- return success();
-}
-
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TransposeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput1().getType());
- ShapeAdaptor permsShape(adaptor.getPerms().getType());
-
- // We cannot infer anything from a rank-0 "permutation" tensor.
- if (permsShape.hasRank() && permsShape.getRank() == 0)
- return failure();
// If input rank and permutation length is unknown, the output rank is
// unknown.
- if (!inputShape.hasRank() || !permsShape.hasRank() ||
- permsShape.isDynamicDim(0)) {
+ if (!inputShape.hasRank()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
return success();
}
// This would imply the number of permutations does not match the rank of
// the input which is illegal.
- if (permsShape.getDimSize(0) != inputShape.getRank()) {
+ if (adaptor.getPerms().size() != static_cast<size_t>(inputShape.getRank())) {
return failure();
}
@@ -1437,28 +1418,16 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
}
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
- // If the permuations are a constant we can directly determine the output
- // shape.
- DenseIntElementsAttr attr;
- if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
- attr.getType().getRank() == 1) {
- ShapeAdaptor permShape = attr;
- // Constant permutation must be the same length as the input rank.
- if (inputShape.getRank() != permShape.getRank())
- return emitOptionalError(location,
- "constant permutation must be the same length"
- " as the input rank");
-
- // Constant permutation values must be within the input rank.
- for (int i = 0, e = inputShape.getRank(); i < e; i++) {
- if (inputShape.getRank() <= permShape.getDimSize(i))
- return failure();
- }
- outputShape.reserve(inputShape.getRank());
- for (int i = 0, s = inputShape.getRank(); i < s; i++) {
- outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
- }
+ // Constant permutation values must be within the input rank.
+ for (auto i : adaptor.getPerms()) {
+ if (inputShape.getRank() <= i)
+ return failure();
+ }
+
+ outputShape.reserve(inputShape.getRank());
+ for (int i = 0, s = inputShape.getRank(); i < s; i++) {
+ outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
}
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -1467,75 +1436,61 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
LogicalResult tosa::TransposeOp::verify() {
TensorType inputType = getInput1().getType();
- TensorType permType = getPerms().getType();
TensorType outputType = getOutput().getType();
+ const llvm::ArrayRef<int32_t> constantPerms = getPerms();
- if (permType.hasRank() && permType.getRank() != 1)
- return emitOpError()
- << "expected permutation tensor to be rank 1 but got rank "
- << permType.getRank();
- if (inputType.hasRank() && permType.hasRank())
- if (!permType.isDynamicDim(0) &&
- permType.getDimSize(0) != inputType.getRank())
- return emitOpError() << "expected permutation tensor dim 0 to have size "
+ if (inputType.hasRank())
+ if (constantPerms.size() != static_cast<size_t>(inputType.getRank()))
+ return emitOpError() << "expected perms attribute to have size "
<< inputType.getRank()
<< " (input rank) but got size "
- << permType.getDimSize(0);
+ << constantPerms.size();
if (inputType.hasRank() && outputType.hasRank() &&
inputType.getRank() != outputType.getRank())
return emitOpError()
<< "expected input tensor rank to equal result tensor rank";
- if (outputType.hasRank() && permType.hasRank())
- if (!permType.isDynamicDim(0) &&
- permType.getDimSize(0) != outputType.getRank())
- return emitOpError() << "expected permutation tensor dim 0 to have size "
+ if (outputType.hasRank())
+ if (constantPerms.size() != static_cast<size_t>(outputType.getRank()))
+ return emitOpError() << "expected perms attribute to have size "
<< outputType.getRank()
<< " (output rank) but got size "
- << permType.getDimSize(0);
-
- SmallVector<int32_t> constantPerms;
- if (succeeded(getConstantPerms(constantPerms))) {
- // Assert that the permutation tensor has a rank, which means that the
- // rank has been verified above.
- assert(permType.hasRank() &&
- "Unexpectedly found permutation tensor without rank");
- if (!llvm::all_of(constantPerms,
- [&constantPerms](int32_t s) {
- return s >= 0 &&
- static_cast<size_t>(s) < constantPerms.size();
- }) ||
- !isPermutationVector(llvm::to_vector(llvm::map_range(
- constantPerms, [](int32_t v) -> int64_t { return v; }))))
- return emitOpError() << "expected valid permutation tensor";
-
- // Verify that the types of the input and output tensors are properly
- // permuted.
- if (inputType.hasRank() && outputType.hasRank()) {
- assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
- inputType.getRank() == outputType.getRank());
-
- for (auto i = 0; i < outputType.getRank(); i++) {
- if (inputType.isDynamicDim(constantPerms[i]) ||
- outputType.isDynamicDim(i))
- continue;
-
- if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
- return emitOpError()
- << "expected output tensor dim " << i << " to match "
- << "input dim " << constantPerms[i] << " with value of "
- << inputType.getDimSize(constantPerms[i]);
- }
+ << constantPerms.size();
+
+ if (!llvm::all_of(constantPerms,
+ [&constantPerms](int32_t s) {
+ return s >= 0 &&
+ static_cast<size_t>(s) < constantPerms.size();
+ }) ||
+ !isPermutationVector(llvm::to_vector(llvm::map_range(
+ constantPerms, [](int32_t v) -> int64_t { return v; }))))
+ return emitOpError() << "expected valid permutation indices";
+
+ // Verify that the types of the input and output tensors are properly
+ // permuted.
+ if (inputType.hasRank() && outputType.hasRank()) {
+ assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
+ inputType.getRank() == outputType.getRank());
+
+ for (auto i = 0; i < outputType.getRank(); i++) {
+ if (inputType.isDynamicDim(constantPerms[i]) ||
+ outputType.isDynamicDim(i))
+ continue;
+
+ if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
+ return emitOpError()
+ << "expected output tensor dim " << i << " to match "
+ << "input dim " << constantPerms[i] << " with value of "
+ << inputType.getDimSize(constantPerms[i]);
}
}
+
return success();
}
LogicalResult TransposeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
- SmallVector<int32_t> transposePerms;
- if (getConstantPerms(transposePerms).failed())
- return failure();
+ const llvm::ArrayRef<int32_t> transposePerms = getPerms();
Value input = getInput1();
auto inputType = cast<TensorType>(input.getType());
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 26baddcf1dd15..61011b6df4617 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -166,13 +166,9 @@ class TransposeConvStridedConverter
getTosaConstShape(rewriter, loc, weightReshapeDims0));
// Transpose the factored-out stride to the output channels.
- Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
- loc, RankedTensorType::get({6}, rewriter.getI32Type()),
- rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
-
weight = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- transposeWeightVal);
+ rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
// Collapse the strides and output channels into a single dimension.
llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
@@ -269,13 +265,9 @@ class TransposeConvStridedConverter
convReshapeDims0Value);
// Transpose the factored-out stride to the output channels.
- Value transposeConvVal = rewriter.create<tosa::ConstOp>(
- loc, RankedTensorType::get({6}, rewriter.getI32Type()),
- rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
-
conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
- transposeConvVal);
+ rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 403ac48b91559..43e9507b4d95a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -224,13 +224,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
return failure();
- DenseIntElementsAttr permAttr;
- if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
- return failure();
auto permValues = llvm::map_to_vector(
- // TOSA allows both 32- and 64-bit integer tensors here.
- permAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); });
+ op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
auto inputType = cast<ShapedType>(op.getInput1().getType());
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 64e5c31793f84..d4d8aae8b0316 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -367,9 +367,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
- SmallVector<int32_t> perms;
- if (failed(transposeOp.getConstantPerms(perms)) ||
- !areInvolutionTransposes(hoistedPerms, perms))
+ if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
return std::nullopt;
return transposeOp.getInput1();
}
@@ -506,14 +504,11 @@ bool TosaReduceTransposes::dependenciesAreValid(
// replaced.
Operation *user = use.getOwner();
if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
- SmallVector<int32_t> otherPerms;
-
// Can later think about cases where transpose -> transpose
// or reshape -> transpose, where the transposes are not necessarily
// the same perms as the hoisted, if implementing a more general
// transform. These could be permitted.
- if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
- !llvm::equal(perms, otherPerms))
+ if (!llvm::equal(perms, otherTranspose.getPerms()))
return false;
} else if (userNotContainedInValidTransposeDependencies(
user, validTransposes, transposeInfo)) {
@@ -607,9 +602,9 @@ void TosaReduceTransposes::runOnOperation() {
!llvm::isa<RankedTensorType>(output.getType()))
return;
- // No transformation when transpose permutation non-constant.
- if (failed(transposeOp.getConstantPerms(perms)))
- return;
+ for (int32_t v : transposeOp.getPerms()) {
+ perms.push_back(v);
+ }
// We let --canonicalize deal with identity transpose.
if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..f2abb29b4fe66 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -56,15 +56,6 @@ static LogicalResult checkConstantOperandPad(Operation *op) {
return success();
}
-static LogicalResult checkConstantOperandTranspose(Operation *op) {
- if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
- DenseElementsAttr perms;
- if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms)))
- return op->emitOpError("perms of transpose is not constant");
- }
- return success();
-}
-
struct TosaLevel {
int32_t MAX_RANK = 0;
int32_t MAX_KERNEL = 0;
@@ -118,7 +109,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
private:
void populateConstantOperandChecks() {
constCheckers.emplace_back(checkConstantOperandPad);
- constCheckers.emplace_back(checkConstantOperandTranspose);
}
bool levelCheckKernel(Operation *op, int32_t v,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index a524359b49759..e0955513ccfd8 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -463,7 +463,6 @@ func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tens
// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
- // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
// HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
@@ -485,7 +484,6 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
- // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi32>
// HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
@@ -786,7 +784,6 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
// CHECK-LABEL: @conv3d_f32
func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
- // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
@@ -824,7 +821,6 @@ func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: t
// CHECK-LABEL: @conv3d_i8
func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
- // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
// CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
@@ -852,10 +848,9 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
// CHECK-LABEL: @test_transpose
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
- %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x1xi32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x2x3xi32>) outs(%[[INIT]] : tensor<2x3x1xi32>) permutation = [1, 2, 0]
- %1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 2, 0>}: (tensor<1x2x3xi32>) -> tensor<2x3x1xi32>
return
}
@@ -864,12 +859,11 @@ func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
// CHECK-LABEL: @test_transpose_dyn
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
- %0 = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32>
// CHECK: %[[C1:.+]] = arith.constant 1
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor<?x4x1x3xi32>
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs(%[[INIT]] : tensor<?x4x1x3xi32>) permutation = [1, 3, 0, 2]
- %1 = tosa.transpose %arg0, %0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x4x1x3xi32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 3, 0, 2>}: (tensor<1x?x3x4xi32>) -> tensor<?x4x1x3xi32>
return
}
@@ -878,14 +872,13 @@ func.func @test_transpose_dyn(%arg0: tensor<1x?x3x4xi32>) -> () {
// CHECK-LABEL: @test_transpose_dyn_multiple_2d
// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
- %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]])
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) permutation = [1, 0]
- %1 = tosa.transpose %arg0, %0 : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<?x?xf32>) -> tensor<?x?xf32>
return
}
@@ -894,7 +887,6 @@ func.func @test_transpose_dyn_multiple_2d(%arg0: tensor<?x?xf32>) -> () {
// CHECK-LABEL: @test_transpose_dyn_multiple_3d
// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?x?xf32>)
func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
- %0 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -903,6 +895,6 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM2]], %[[DIM0]], %[[DIM1]]) : tensor<?x?x?xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?xf32>) permutation = [2, 0, 1]
- %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
+ %1 = "tosa.transpose"(%arg0) {perms = array<i32: 2, 0, 1>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return
}
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index ef8b80f6b5c22..e354eb91d7557 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -34,8 +34,7 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
// CHECK-NEXT: tensor.dim %[[arg]], %[[c2]]
// CHECK-NEXT: return
func.func @dynamic_dim_of_transpose_op(%arg0: tensor<1x2x?x8xi8>) -> index {
- %0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<1x2x?x8xi8>, tensor<4xi32>) -> tensor<1x8x2x?xi8>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 0, 3, 1, 2> }: (tensor<1x2x?x8xi8>) -> tensor<1x8x2x?xi8>
%c3 = arith.constant 3 : index
%dim = tensor.dim %1, %c3 : tensor<1x8x2x?xi8>
return %dim : index
@@ -47,8 +46,7 @@ func.func @dynamic_dim_of_transpose_op(%arg0: tensor<1x2x?x8xi8>) -> index {
// CHECK: arith.constant 100 : index
// CHECK: return
func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
- %0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<1x100x?x8xi8>, tensor<4xi32>) -> tensor<1x8x100x?xi8>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 0, 3, 1, 2> }: (tensor<1x100x?x8xi8>) -> tensor<1x8x100x?xi8>
%c2 = arith.constant 2 : index
%dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8>
return %dim : index
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index e66ff4cacfd89..879ee450cf88f 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -553,10 +553,9 @@ func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
// -----
// CHECK-LABEL: transpose
func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
- %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
- %1 = tosa.transpose %arg0, %0 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>}: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
return %1 : tensor<3x13x21xf32>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 0e177a076ee7a..9f3a6057da844 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -631,12 +631,11 @@ func.func @reshape_canonicalize_quant_nofold() -> (tensor<1x3x!quant.uniform<i8:
// CHECK-LABEL: @transpose_canonicalize_strip_quant
func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>) {
- // CHECK-DAG: tosa.const_shape {value = dense<[2, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
- // CHECK-DAG: "tosa.const"() <{value = dense<0> : tensor<1x2x3xi8>}> : () -> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
- // CHECK: tosa.reshape %0, %1 : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, !tosa.shape<3>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
- %perms = "tosa.const"() {value = dense<[1, 0, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
+ // CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[2, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK-DAG: %[[CONST:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x2x3xi8>}> : () -> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
+ // CHECK: tosa.reshape %[[CONST]], %[[SHAPE]] : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, !tosa.shape<3>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
%0 = "tosa.const"() {value = dense<0> : tensor<1x2x3xi8>} : ()-> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
- %1 = tosa.transpose %0, %perms : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, tensor<3xi32>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
+ %1 = tosa.transpose %0 { perms = array<i32: 1, 0, 2> }: (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
return %1 : tensor<2x1x3x!quant.uniform<i8:f32, 1.000000e+00>>
}
@@ -688,8 +687,7 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
// CHECK: return %arg0
// CHECK-NOT: tosa.transpose
- %perms = "tosa.const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %arg0, %perms : (tensor<3x4x5x6xf32>, tensor<4xi32>) -> tensor<3x4x5x6xf32>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 0, 1, 2, 3> }: (tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32>
return %1 : tensor<3x4x5x6xf32>
}
@@ -699,13 +697,22 @@ func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> {
// CHECK: %[[CONST0:.+]] = tosa.const_shape {value = dense<[1, 4, 1, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: tosa.reshape %arg0, %[[CONST0]]
- %perms = "tosa.const"() <{value = dense<[3, 1, 0, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
- %0 = tosa.transpose %arg0, %perms : (tensor<1x4x5x1xf32>, tensor<4xi32>) -> tensor<1x4x1x5xf32>
+ %0 = tosa.transpose %arg0 { perms = array<i32: 3, 1, 0, 2> }: (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
return %0 : tensor<1x4x1x5xf32>
}
// -----
+// CHECK-LABEL: @transpose_is_reshape_unknown_dim
+func.func @transpose_is_reshape_unknown_dim(%arg0: tensor<1x4x?x1xf32>) -> tensor<1x4x1x?xf32> {
+ // CHECK: %[[CONST0:.+]] = tosa.const_shape {value = dense<[1, 4, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // CHECK: tosa.reshape %arg0, %[[CONST0]]
+ %0 = tosa.transpose %arg0 { perms = array<i32: 3, 1, 0, 2> }: (tensor<1x4x?x1xf32>) -> tensor<1x4x1x?xf32>
+ return %0 : tensor<1x4x1x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: @single_bit_reshape
// https://github.com/llvm/llvm-project/issues/55440
func.func @single_bit_reshape() -> tensor<1xi1> {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index e6fb741df9598..190aa777d3470 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -20,34 +20,30 @@ func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor<?x1x3xf32>) ->
// CHECK-LABEL: @transpose_fold
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
- %0 = arith.constant dense<[0, 1]> : tensor<2xi32>
- %1 = tosa.transpose %arg0, %0 { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 0, 1> }: (tensor<3x4xf32>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
}
// CHECK-LABEL: @transpose_nofold
func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
// CHECK: tosa.transpose
- %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
- %1 = tosa.transpose %arg0, %0 { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 1, 0> }: (tensor<3x3xf32>) -> tensor<3x3xf32>
return %1 : tensor<3x3xf32>
}
// CHECK-LABEL: @transpose_nofold_shape
func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
// CHECK: tosa.transpose
- %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
- %1 = tosa.transpose %arg0, %0 { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 1, 0> }: (tensor<3x4xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @transpose_fold_splat
func.func @transpose_fold_splat() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"() <{
// CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
- %1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %1 = tosa.transpose %input { perms = array<i32: 1, 0> }: (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
@@ -55,10 +51,9 @@ func.func @transpose_fold_splat() -> tensor<3x2xf32> {
// CHECK-LABEL: @transpose_fold_2d_float
func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"() <{
// CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
- %1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %1 = tosa.transpose %input { perms = array<i32: 1, 0> }: (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
@@ -66,10 +61,9 @@ func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
// CHECK-LABEL: @transpose_fold_2d_bool
func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> {
%input = "tosa.const"() {value = dense<[[true, false, false], [false, false, true]]> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"() <{
// CHECK-SAME{LITERAL}: value = dense<[[true, false], [false, false], [false, true]]> : tensor<3x2xi1>
- %1 = tosa.transpose %input, %perms : (tensor<2x3xi1>, tensor<2xi32>) -> tensor<3x2xi1>
+ %1 = tosa.transpose %input { perms = array<i32: 1, 0> }: (tensor<2x3xi1>) -> tensor<3x2xi1>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xi1>
}
@@ -80,50 +74,46 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
- %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[CST:.+]] = "tosa.const"() <{
// CHECK-SAME{LITERAL}: value = dense<[
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
// CHECK-SAME{LITERAL}: ]>
- %1 = tosa.transpose %input, %perms : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<3x1x4x2xi32>
+ %1 = tosa.transpose %input { perms = array<i32: 2, 0, 3, 1> }: (tensor<1x2x3x4xi32>) -> tensor<3x1x4x2xi32>
// CHECK: return %[[CST]]
return %1 : tensor<3x1x4x2xi32>
}
// CHECK-LABEL: @transpose_nofold_non_cst_input
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tosa.transpose
- %1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
- return %1 : tensor<3x2xf32>
-}
-
-// CHECK-LABEL: @transpose_nofold_non_cst_perms
-func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
- %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- // CHECK: tosa.transpose
- %1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %1 = tosa.transpose %input { perms = array<i32: 1, 0> }: (tensor<2x3xf32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_nofold_multi_users
func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tosa.transpose
- %1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %1 = tosa.transpose %input { perms = array<i32: 1, 0> }: (tensor<2x3xf32>) -> tensor<3x2xf32>
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
}
+// CHECK-LABEL: @transpose_nofold_quantized_types
+func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
+ %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+ // CHECK: tosa.transpose
+ %0 = tosa.transpose %input { perms = array<i32: 1, 2, 3, 0> }: (tensor<2x1x1x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+ return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+}
+
// CHECK-LABEL: @transpose_nofold_dense_resource
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
%0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
- %1 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK: tosa.transpose
- %2 = tosa.transpose %0, %1 : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xf32>
+ %2 = tosa.transpose %0 { perms = array<i32: 1, 0> }: (tensor<2x2xf32>) -> tensor<2x2xf32>
return %2 : tensor<2x2xf32>
}
{-#
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1aa8547cb2fdb..44f208aee212d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -235,97 +235,74 @@ func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor
// -----
-func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
- // expected-error@+1 {{'tosa.transpose' op perms of transpose is not constant}}
- %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
- return %0 : tensor<3x13x21xf32>
-}
-
-// -----
-
func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> {
// expected-error@+1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}}
- %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21x1xf32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>}: (tensor<13x21x3xf32>) -> tensor<3x13x21x1xf32>
return %0 : tensor<3x13x21x1xf32>
}
// -----
-func.func @test_transpose_invalid_perms_rank(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<3x13x21xf32> {
- // expected-error@+1 {{'tosa.transpose' op expected permutation tensor to be rank 1 but got rank 2}}
- %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<3x13x21xf32>
- return %0 : tensor<3x13x21xf32>
-}
-
-// -----
-
func.func @test_transpose_rank0_perms() {
%14 = tensor.empty() : tensor<5x27xi64>
- %cst = tensor.empty() : tensor<i32>
- // expected-error@+1 {{'tosa.transpose' op expected permutation tensor to be rank 1 but got rank 0}}
- %72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64>
+ // expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 2 (input rank) but got size 0}}
+ %72 = tosa.transpose %14 {perms = array<i32> }: (tensor<5x27xi64>) -> tensor<?x?xi64>
return
}
// -----
-func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7xi32>) -> tensor<3x13x21xf32> {
- // expected-error@+1 {{'tosa.transpose' op expected permutation tensor dim 0 to have size 3 (input rank) but got size 7}}
- %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<7xi32>) -> tensor<3x13x21xf32>
+func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
+ // expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 3 (input rank) but got size 7}}
+ %0 = tosa.transpose %arg0 {perms = array<i32: 6, 5, 4, 3, 2, 1, 0> }: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
return %0 : tensor<3x13x21xf32>
}
// -----
func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor<?x?x?xf32> {
- %perms = arith.constant dense<[2, 0, 0]> : tensor<3xi32>
- // expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
- %0 = tosa.transpose %arg0, %perms : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
+ // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
+ %0 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 0> }: (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// -----
func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
- %perms = "tosa.const"() {value = dense<[-1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
- // expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
- %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
+ // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: -1, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
- %perms = "tosa.const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
- // expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
- %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
+ // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// expected-error@+1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
- %1 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x4xi32>
return %1 : tensor<3x4xi32>
}
// -----
func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// expected-error@+1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
- %1 = tosa.transpose %arg0, %perms : (tensor<2x?xi32>, tensor<2xi32>) -> tensor<3x4xi32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<2x?xi32>) -> tensor<3x4xi32>
return %1 : tensor<3x4xi32>
}
// -----
func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
- %1 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>} : (tensor<2x3xi32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
@@ -674,10 +651,9 @@ func.func @test_tile_io_rank_mismatch() {
// CHECK-LABEL: @test_invalid_constant_permutation
func.func @test_invalid_constant_permutation() {
- // expected-error@+3 {{'tosa.transpose' op expected valid permutation tensor}}
%0 = tensor.empty() : tensor<3x4x5xi32>
- %1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
- %2 = tosa.transpose %0, %1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
+ // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
+ %2 = tosa.transpose %0 {perms = array<i32: 3, 0, 1>}: (tensor<3x4x5xi32>) -> tensor<3x4x5xi32>
return
}
@@ -685,11 +661,10 @@ func.func @test_invalid_constant_permutation() {
// CHECK-LABEL: test_rank_size_constant_permutation
func.func @test_rank_size_constant_permutation() {
- // expected-error@+4 {{'tosa.transpose' op expected valid permutation tensor}}
%0 = arith.constant 6 : index
- %1 = arith.constant dense<[0, 2]> : tensor<2xi32>
%2 = tensor.empty(%0) : tensor<?x27xi64>
- %3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+ // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
+ %3 = tosa.transpose %2 {perms = array<i32: 0, 2>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
return
}
@@ -697,11 +672,10 @@ func.func @test_rank_size_constant_permutation() {
// CHECK-LABEL: test_large_constant_permutation
func.func @test_large_constant_permutation() {
- // expected-error@+4 {{'tosa.transpose' op expected valid permutation tensor}}
%0 = arith.constant 6 : index
- %1 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
%2 = tensor.empty(%0) : tensor<?x27xi64>
- %3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
+ // expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
+ %3 = tosa.transpose %2 {perms = array<i32: 1185677355, 332462212>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
return
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 90c4551564d1e..67efab7f1559f 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -105,9 +105,8 @@ func.func @test_tile(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21
// -----
func.func @test_transpose(%arg0: tensor<13x21x3x1x1x1x1xf32>) -> tensor<3x13x21x1x1x1x1xf32> {
- %0 = "tosa.const"() {value = dense<[2, 0, 1, 3, 4, 5, 6]> : tensor<7xi32>} : () -> tensor<7xi32>
// expected-error@+1 {{'tosa.transpose' op failed level check: operand rank(shape) <= MAX_RANK}}
- %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3x1x1x1x1xf32>, tensor<7xi32>) -> tensor<3x13x21x1x1x1x1xf32>
+ %1 = "tosa.transpose"(%arg0) {perms = array<i32: 2, 0, 1, 3, 4, 5, 6>} : (tensor<13x21x3x1x1x1x1xf32>) -> tensor<3x13x21x1x1x1x1xf32>
return %1 : tensor<3x13x21x1x1x1x1xf32>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index cd73377c7f587..3694c7d472bc3 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -640,24 +640,21 @@ func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
// -----
// CHECK-LABEL: transpose
func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
- %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
return %1 : tensor<3x13x21xf32>
}
// -----
// CHECK-LABEL: transpose_dynamic_dim
func.func @test_transpose_dynamic_dim(%arg0: tensor<13x?x3xf32>) -> tensor<3x13x?xf32> {
- %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<13x?x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x?x3xf32>) -> tensor<3x13x?xf32>
return %1 : tensor<3x13x?xf32>
}
// -----
// CHECK-LABEL: transpose_half_dynamic_dim
func.func @test_transpose_half_dynamic_dim(%arg0: tensor<13x3x3xf32>) -> tensor<3x13x?xf32> {
- %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<13x3x3xf32>, tensor<3xi32>) -> tensor<3x13x?xf32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x3x3xf32>) -> tensor<3x13x?xf32>
return %1 : tensor<3x13x?xf32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index ae7a8e90b4281..944dd5f2a4d10 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -54,11 +54,10 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1:
func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
// Manipulate the weight matrix to handle striding.
// CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
- // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]]
// CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[5, 2, 2, 2, 3, 3]> : tensor<6xindex>}
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]], %[[CONST1]]
- // CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]]
+ // CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]] {perms = array<i32: 2, 4, 0, 1, 3, 5>}
// CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[30, 2, 2, 3]> : tensor<4xindex>}
// CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]], %[[CONST3]]
// CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[RESW2]] {axis = 1 : i32}
@@ -68,7 +67,6 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// Pad out the input matrix to handle the transpose conv.
// CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
- // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]]
// Manipulate the final shape.
@@ -76,7 +74,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
// CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
// CHECK-DAG: %[[CONST6:.+]] = tosa.const_shape {value = dense<[2, 18, 16, 2, 3, 5]> : tensor<6xindex>}
// CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]], %[[CONST6]]
- // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
+ // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]] {perms = array<i32: 0, 1, 3, 2, 4, 5>}
// CHECK-DAG: %[[CONST8:.+]] = tosa.const_shape {value = dense<[2, 36, 48, 5]> : tensor<4xindex>
// CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]], %[[CONST8]]
// CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
@@ -95,11 +93,10 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1: tensor<5x3x5x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x35x47x5xi32>) {
// Manipulate the weight matrix to handle striding.
// CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
- // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {input_zp = 42 : i32}
// CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[5, 2, 2, 2, 3, 3]> : tensor<6xindex>}
// CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]], %[[CONST1]]
- // CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]]
+ // CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]] {perms = array<i32: 2, 4, 0, 1, 3, 5>}
// CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[30, 2, 2, 3]> : tensor<4xindex>}
// CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]], %[[CONST3]]
// CHECK-DAG: %[[REV1:.+]] = tosa.reverse %[[RESW2]] {axis = 1 : i32}
@@ -109,7 +106,6 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// Pad out the input matrix to handle the transpose conv.
// CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
- // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {input_zp = -22 : i32}
// Manipulate the final shape.
@@ -119,7 +115,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
// CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
// CHECK-DAG: %[[CONV_NEW_SHAPE:.*]] = tosa.const_shape {value = dense<[2, 18, 16, 2, 3, 5]> : tensor<6xindex>}
// CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]], %[[CONV_NEW_SHAPE]]
- // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
+ // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]] {perms = array<i32: 0, 1, 3, 2, 4, 5>}
// CHECK-DAG: %[[TRANS_NEW_SHAPE:.+]] = tosa.const_shape {value = dense<[2, 36, 48, 5]> : tensor<4xindex>}
// CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]], %[[TRANS_NEW_SHAPE]]
// CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
@@ -138,12 +134,10 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
// CHECK-DAG: %[[WEIGHT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 1, 0, 0]> : tensor<8xindex>}
// CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[1, 2, 1, 1, 2, 1]> : tensor<6xindex>}
- // CHECK-DAG: %[[WEIGHT_PERMS:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[CONST3:.+]] = tosa.const_shape {value = dense<[2, 2, 1, 1]> : tensor<4xindex>}
// CHECK-DAG: %[[INPUT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>}
// CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
// CHECK-DAG: %[[CONST6:.+]] = tosa.const_shape {value = dense<[1, 17, 1, 1, 2, 1]> : tensor<6xindex>}
- // CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
// CHECK-DAG: %[[CONST8:.+]] = tosa.const_shape {value = dense<[1, 17, 2, 1]> : tensor<4xindex>}
// CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>}
// CHECK-DAG: %[[CONST10:.+]] = tosa.const_shape {value = dense<1> : tensor<4xindex>}
@@ -151,13 +145,13 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
// CHECK-DAG: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{value = dense<93> : tensor<1xi8>}>
// CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {input_zp = 93 : i32}
// CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]], %[[CONST1]]
- // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
+ // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]] {perms = array<i32: 2, 4, 0, 1, 3, 5>}
// CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]], %[[CONST3]]
// CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32}
// CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {input_zp = -103 : i32}
// CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
// CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]], %[[CONST6]]
- // CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]]
+ // CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]] {perms = array<i32: 0, 1, 3, 2, 4, 5>}
// CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]], %[[CONST8]]
// CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]
// CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2, %[[CONST10]]
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 7e714d0f8547a..008dcc356730b 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -577,29 +577,10 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
// -----
-// CHECK-LABEL: @test_transpose_same
-func.func @test_transpose_same(%arg0 : tensor<4x4x4xi32>, %arg1 : tensor<3xi32>) -> () {
- // CHECK: tosa.transpose %arg0, %arg1 : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<4x4x4xi32>
- %0 = tosa.transpose %arg0, %arg1 : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @test_transpose_perm_unknown
-func.func @test_transpose_perm_unknown(%arg0 : tensor<4x4x5xi32>, %arg1 : tensor<3xi32>) -> () {
- // CHECK: tosa.transpose %arg0, %arg1 : (tensor<4x4x5xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
- %0 = tosa.transpose %arg0, %arg1 : (tensor<4x4x5xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
- return
-}
-
-// -----
-
// CHECK-LABEL: @test_transpose_static
func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
- %0 = arith.constant dense<[2, 1, 0]> : tensor<3xi32>
- // CHECK: tosa.transpose %arg0, %cst : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<5x4x3xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
+ // CHECK: tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>} : (tensor<3x4x5xi32>) -> tensor<5x4x3xi32>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 2, 1, 0> }: (tensor<3x4x5xi32>) -> tensor<?x?x?xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
index 3a293009a5455..ed87a6d7f73db 100644
--- a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
@@ -4,11 +4,9 @@
// CHECK-NEXT: %[[RESULT:.*]] = tosa.ceil %arg0
// CHECK-NEXT: return %[[RESULT]]
func.func @test_transpose_tracks_to_nullifying_single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%ceil = tosa.ceil %0 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %ceil, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %1 = tosa.transpose %ceil {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %1 : tensor<1x2x3x4xi32>
}
@@ -20,13 +18,11 @@ func.func @test_transpose_tracks_to_nullifying_single_step(%arg0: tensor<1x2x3x4
// CHECK-NEXT: %[[NOT:.*]] = tosa.bitwise_not %[[ABS]]
// CHECK-NEXT: return %[[NOT]]
func.func @test_transpose_tracks_to_nullifying_multi_unary_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %0 {max_val = 1 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %bitwise_not, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %1 = tosa.transpose %bitwise_not {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %1 : tensor<1x2x3x4xi32>
}
@@ -38,14 +34,12 @@ func.func @test_transpose_tracks_to_nullifying_multi_unary_step(%arg0: tensor<1x
// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
// CHECK-NEXT: return %[[ADD]]
func.func @test_transpose_tracks_to_nullifying_diverging_binary(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
- %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %transpose0 {max_val = 1 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %result = tosa.transpose %add {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
@@ -58,14 +52,12 @@ func.func @test_transpose_tracks_to_nullifying_diverging_binary(%arg0: tensor<1x
// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
// CHECK-NEXT: return %[[ADD]]
func.func @test_transpose_tracks_to_nullifying_diverging_binary_with_broadcasting(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x1x4xi32>) -> tensor<1x2x3x4xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
- %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x1x4xi32>, tensor<4xi32>) -> tensor<1x1x4x2xi32>
+ %transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x1x4xi32>) -> tensor<1x1x4x2xi32>
%clamp = tosa.clamp %transpose0 {max_val = 1 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %transpose1 : (tensor<1x1x4x2xi32>) -> tensor<1x1x4x2xi32>
%add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x1x4x2xi32>) -> tensor<1x3x4x2xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %result = tosa.transpose %add {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
@@ -75,11 +67,9 @@ func.func @test_transpose_tracks_to_nullifying_diverging_binary_with_broadcastin
// CHECK-NEXT: %[[RESULT:.*]] = tosa.add %arg0, %arg0
// CHECK-NEXT: return %[[RESULT]]
func.func @test_transpose_tracks_to_nullifying__converging_binary(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%clamp = tosa.add %0, %0 : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %clamp, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %1 : tensor<1x2x3x4xi32>
}
@@ -102,20 +92,20 @@ func.func @test_torch_conv2d_with_elementwise_in_between(%arg0: tensor<3x3x10x10
%5 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>}> : () -> tensor<3xf32>
%6 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
%7 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
- %8 = tosa.transpose %arg0, %6 : (tensor<3x3x10x10xf32>, tensor<4xi32>) -> tensor<3x10x10x3xf32>
- %9 = tosa.transpose %4, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %8 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x10x10xf32>) -> tensor<3x10x10x3xf32>
+ %9 = tosa.transpose %4 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x2x2xf32>) -> tensor<3x2x2x3xf32>
%10 = tosa.conv2d %8, %9, %5 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x10x10x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x9x9x3xf32>
- %11 = tosa.transpose %10, %7 : (tensor<3x9x9x3xf32>, tensor<4xi32>) -> tensor<3x3x9x9xf32>
+ %11 = tosa.transpose %10 {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x9x9x3xf32>) -> tensor<3x3x9x9xf32>
%12 = tosa.ceil %11 : (tensor<3x3x9x9xf32>) -> tensor<3x3x9x9xf32>
- %13 = tosa.transpose %12, %6 : (tensor<3x3x9x9xf32>, tensor<4xi32>) -> tensor<3x9x9x3xf32>
- %14 = tosa.transpose %3, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %13 = tosa.transpose %12 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x9x9xf32>) -> tensor<3x9x9x3xf32>
+ %14 = tosa.transpose %3 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x2x2xf32>) -> tensor<3x2x2x3xf32>
%15 = tosa.conv2d %13, %14, %2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x9x9x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x8x8x3xf32>
- %16 = tosa.transpose %15, %7 : (tensor<3x8x8x3xf32>, tensor<4xi32>) -> tensor<3x3x8x8xf32>
+ %16 = tosa.transpose %15 {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x8x8x3xf32>) -> tensor<3x3x8x8xf32>
%17 = tosa.floor %16 : (tensor<3x3x8x8xf32>) -> tensor<3x3x8x8xf32>
- %18 = tosa.transpose %17, %6 : (tensor<3x3x8x8xf32>, tensor<4xi32>) -> tensor<3x8x8x3xf32>
- %19 = tosa.transpose %1, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %18 = tosa.transpose %17 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x8x8xf32>) -> tensor<3x8x8x3xf32>
+ %19 = tosa.transpose %1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x2x2xf32>) -> tensor<3x2x2x3xf32>
%20 = tosa.conv2d %18, %19, %0 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x8x8x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x7x7x3xf32>
- %21 = tosa.transpose %20, %7 : (tensor<3x7x7x3xf32>, tensor<4xi32>) -> tensor<3x3x7x7xf32>
+ %21 = tosa.transpose %20 {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x7x7x3xf32>) -> tensor<3x3x7x7xf32>
return %21 : tensor<3x3x7x7xf32>
}
@@ -126,13 +116,11 @@ func.func @test_torch_conv2d_with_elementwise_in_between(%arg0: tensor<3x3x10x10
// CHECK-NEXT: %[[RES:.*]] = tosa.mul %arg0, %arg1, %[[SHIFT]]
// CHECK-NEXT: return %[[RES]]
func.func @test_mulop_conversion(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
- %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%mul = tosa.mul %transpose0, %transpose1, %shift : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>, tensor<1xi8>) -> tensor<1x3x4x2xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %result = tosa.transpose %mul, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %result = tosa.transpose %mul {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
@@ -141,15 +129,14 @@ func.func @test_mulop_conversion(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3
// COM: this case is a reshape we don't convert, since can't fold the transpose into it.
// COM: a transform actually occurs underneath the hood, but it results in identical IR.
// CHECK-LABEL: @test_basic_non_broadcasting_reshape
-// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[1, 3, 2]> : tensor<3xindex>}
-// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}>
-// CHECK: %[[VAL_3:.*]] = tosa.reshape %arg0, %[[VAL_1]] : (tensor<2x3xi32>, !tosa.shape<3>) -> tensor<1x3x2xi32>
-// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_3]], %[[VAL_2]] : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+// CHECK: %[[SHAPE:.+]] = tosa.const_shape {value = dense<[1, 3, 2]> : tensor<3xindex>}
+// CHECK: %[[RESHAPED:.+]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<2x3xi32>, !tosa.shape<3>) -> tensor<1x3x2xi32>
+// CHECK: tosa.transpose %[[RESHAPED]] {perms = array<i32: 0, 2, 1>} : (tensor<1x3x2xi32>) -> tensor<1x2x3xi32>
func.func @test_basic_non_broadcasting_reshape(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> {
%shape = tosa.const_shape {value = dense<[1, 3, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
%perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = tosa.reshape %arg0, %shape : (tensor<2x3xi32>, !tosa.shape<3>) -> tensor<1x3x2xi32>
- %2 = tosa.transpose %1, %perms : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+ %2 = tosa.transpose %1 {perms = array<i32: 0, 2, 1>}: (tensor<1x3x2xi32>) -> tensor<1x2x3xi32>
return %2 : tensor<1x2x3xi32>
}
@@ -163,7 +150,7 @@ func.func @test_dynamic_broadcasting_reshape(%arg0: tensor<?xi32>) -> tensor<1x1
%shape = tosa.const_shape {value = dense<[1, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = tosa.reshape %arg0, %shape : (tensor<?xi32>, !tosa.shape<3>) -> tensor<1x?x1xi32>
- %2 = tosa.transpose %1, %perms : (tensor<1x?x1xi32>, tensor<3xi32>) -> tensor<1x1x?xi32>
+ %2 = tosa.transpose %1 {perms = array<i32: 0, 2, 1>}: (tensor<1x?x1xi32>) -> tensor<1x1x?xi32>
return %2 : tensor<1x1x?xi32>
}
@@ -179,10 +166,9 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x
%0 = "tosa.const"() {value = dense<[1,2,3,4]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = tosa.const_shape {value = dense<[1, 1, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
%reshape = tosa.reshape %0, %1 : (tensor<4xi32>, !tosa.shape<3>) -> tensor<1x1x4xi32>
- %perms0 = "tosa.const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
- %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<4x3x2xi32>, tensor<3xi32>) -> tensor<2x3x4xi32>
+ %transpose0 = tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>}: (tensor<4x3x2xi32>) -> tensor<2x3x4xi32>
%add = tosa.add %transpose0, %reshape : (tensor<2x3x4xi32>, tensor<1x1x4xi32>) -> tensor<2x3x4xi32>
- %transpose1 = tosa.transpose %add, %perms0 : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x3x2xi32>
+ %transpose1 = tosa.transpose %add {perms = array<i32: 2, 1, 0>}: (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
return %transpose1 : tensor<4x3x2xi32>
}
@@ -223,7 +209,7 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32
%64 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
%69 = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
%70 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
- %75 = tosa.transpose %74, %64 : (tensor<1x112x112x64xf32>, tensor<4xi32>) -> tensor<1x64x112x112xf32>
+ %75 = tosa.transpose %74 {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x112x112x64xf32>) -> tensor<1x64x112x112xf32>
%76 = tosa.add %arg1, %69 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
%77 = tosa.pow %76, %70 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
%78 = tosa.reciprocal %77 : (tensor<64xf32>) -> tensor<64xf32>
@@ -236,54 +222,45 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32
%85 = tosa.reshape %59, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32>
%86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
%87 = tosa.clamp %86 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>
- %88 = tosa.transpose %87, %63 : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
+ %88 = tosa.transpose %87 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x64x112x112xf32>) -> tensor<1x112x112x64xf32>
return %88 : tensor<1x112x112x64xf32>
}
// -----
// CHECK-LABEL: @test_back_to_back_nullifiers
-// CHECK: %[[PERMS:.*]] = "tosa.const"
-// CHECK: %[[RES:.*]] = tosa.transpose %arg0, %[[PERMS]]
+// CHECK: %[[RES:.*]] = tosa.transpose %arg0 {perms = array<i32: 1, 0>}
// CHECK: return %[[RES]]
func.func @test_back_to_back_nullifiers(%arg0: tensor<2x3xi32>) -> tensor<3x2xi32> {
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
- %0 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
- %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
- %2 = tosa.transpose %1, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<2x3xi32>) -> tensor<3x2xi32>
+ %1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<3x2xi32>) -> tensor<2x3xi32>
+ %2 = tosa.transpose %1 {perms = array<i32: 1, 0>}: (tensor<2x3xi32>) -> tensor<3x2xi32>
return %2 : tensor<3x2xi32>
}
// -----
// CHECK-LABEL: @test_back_to_back_nullifiers_different_transposes
-// CHECK: %[[PERMS:.*]] = "tosa.const"
-// CHECK: %[[RES:.*]] = tosa.transpose %arg0, %[[PERMS]]
+// CHECK: %[[RES:.*]] = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}
// CHECK: return %[[RES]]
func.func @test_back_to_back_nullifiers_different_transposes(%arg0: tensor<2x3x4x5xi32>) -> tensor<2x4x5x3xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %0 = tosa.transpose %arg0, %perms0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<2x4x5x3xi32>
- %1 = tosa.transpose %0, %perms1 : (tensor<2x4x5x3xi32>, tensor<4xi32>) -> tensor<2x3x4x5xi32>
- %2 = tosa.transpose %1, %perms0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<2x4x5x3xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<2x3x4x5xi32>) -> tensor<2x4x5x3xi32>
+ %1 = tosa.transpose %0 {perms = array<i32: 0, 3, 1, 2>}: (tensor<2x4x5x3xi32>) -> tensor<2x3x4x5xi32>
+ %2 = tosa.transpose %1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<2x3x4x5xi32>) -> tensor<2x4x5x3xi32>
return %2 : tensor<2x4x5x3xi32>
}
// -----
// CHECK-LABEL: @test_no_transform_if_outside_fan_in_cone
-// CHECK: tosa.const
// CHECK: %[[CLAMP_IN:.*]] = tosa.transpose
// CHECK: %[[RES2:.*]] = tosa.clamp %[[CLAMP_IN]]
-// CHECK: tosa.const
// CHECK: %[[RES1:.*]] = tosa.transpose
// CHECK: return %[[RES1]], %[[RES2]]
func.func @test_no_transform_if_outside_fan_in_cone(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
return %1, %clamp : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
}
@@ -298,29 +275,25 @@ func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: t
%shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%1 = tosa.reshape %arg0, %shape : (tensor<64xf32>, !tosa.shape<3>) -> tensor<1x64x1xf32>
%2 = tosa.clamp %1 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32>
- %3 = tosa.transpose %1, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
- %4 = tosa.transpose %2, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
+ %3 = tosa.transpose %1 {perms = array<i32: 0, 2, 1>}: (tensor<1x64x1xf32>) -> tensor<1x1x64xf32>
+ %4 = tosa.transpose %2 {perms = array<i32: 0, 2, 1>}: (tensor<1x64x1xf32>) -> tensor<1x1x64xf32>
return %3, %4 : tensor<1x1x64xf32>, tensor<1x1x64xf32>
}
// -----
// CHECK-LABEL: @test_two_different_downstream_converge_to_reshape_different_perms
-// CHECK-DAG: tosa.const
-// CHECK-DAG: tosa.const
// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape
// CHECK-DAG: %[[CLAMP:.*]] = tosa.clamp %[[RESHAPE]]
// CHECK-DAG: %[[RET1:.*]] = tosa.transpose
// CHECK-DAG: %[[RET2:.*]] = tosa.transpose
// CHECK-DAG: return %[[RET1]], %[[RET2]]
func.func @test_two_different_downstream_converge_to_reshape_different_perms(%arg0: tensor<64xf32>) -> (tensor<1x1x64xf32>, tensor<64x1x1xf32>) {
- %0 = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
- %1 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
%shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%2 = tosa.reshape %arg0, %shape : (tensor<64xf32>, !tosa.shape<3>) -> tensor<1x64x1xf32>
%3 = tosa.clamp %2 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32>
- %4 = tosa.transpose %2, %1 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
- %5 = tosa.transpose %3, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<64x1x1xf32>
+ %4 = tosa.transpose %2 {perms = array<i32: 0, 2, 1>}: (tensor<1x64x1xf32>) -> tensor<1x1x64xf32>
+ %5 = tosa.transpose %3 {perms = array<i32: 1, 2, 0>}: (tensor<1x64x1xf32>) -> tensor<64x1x1xf32>
return %4, %5 : tensor<1x1x64xf32>, tensor<64x1x1xf32>
}
@@ -328,16 +301,15 @@ func.func @test_two_different_downstream_converge_to_reshape_different_perms(%ar
// COM: no transform
// CHECK-LABEL: @test_outside_perms_usage_of_fan_in
-// CHECK: tosa.const
// CHECK: tosa.transpose
// CHECK: tosa.clamp
// CHECK: %[[RES1:.*]] = tosa.transpose
// CHECK: %[[RES2:.*]] = tosa.add
// CHECK: return %[[RES1]], %[[RES2]]
-func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> (tensor<2x3xf32>, tensor<3x2xf32>) { %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> (tensor<2x3xf32>, tensor<3x2xf32>) {
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<2x3xf32>) -> tensor<3x2xf32>
%2 = tosa.clamp %1 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<3x2xf32>) -> tensor<3x2xf32>
- %3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ %3 = tosa.transpose %2 {perms = array<i32: 1, 0>}: (tensor<3x2xf32>) -> tensor<2x3xf32>
%4 = tosa.add %arg1, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
return %3, %4: tensor<2x3xf32>, tensor<3x2xf32>
}
@@ -351,13 +323,12 @@ func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: ten
// CHECK-DAG: %[[NEW_ADD:.*]] = tosa.add %arg1, %[[NEW_CLAMP]]
// CHECK: return %[[NEW_CLAMP]], %[[NEW_ADD]]
func.func @test_use_present_in_another_valid_perms_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
- %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<2x3xf32>) -> tensor<3x2xf32>
%2 = tosa.clamp %1 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<3x2xf32>) -> tensor<3x2xf32>
- %3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
- %4 = tosa.transpose %arg1, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %3 = tosa.transpose %2 {perms = array<i32: 1, 0>}: (tensor<3x2xf32>) -> tensor<2x3xf32>
+ %4 = tosa.transpose %arg1 {perms = array<i32: 1, 0>}: (tensor<2x3xf32>) -> tensor<3x2xf32>
%5 = tosa.add %4, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
- %6 = tosa.transpose %5, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ %6 = tosa.transpose %5 {perms = array<i32: 1, 0>}: (tensor<3x2xf32>) -> tensor<2x3xf32>
return %3, %6: tensor<2x3xf32>, tensor<2x3xf32>
}
@@ -365,7 +336,6 @@ func.func @test_use_present_in_another_valid_perms_fan_in(%arg0: tensor<2x3xf32>
// COM: no transform, since we would get duplicates
// CHECK-LABEL: @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents
-// CHECK: tosa.const
// CHECK: tosa.transpose
// CHECK: %[[CEIL:.*]] = tosa.ceil
// CHECK: %[[ADD:.*]] = tosa.add %[[CEIL]]
@@ -373,12 +343,11 @@ func.func @test_use_present_in_another_valid_perms_fan_in(%arg0: tensor<2x3xf32>
// CHECK: %[[RES2:.*]] = tosa.transpose %[[ADD]]
// CHECK: return %[[RES1]], %[[RES2]]
func.func @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents(%arg0: tensor<2x3xi32>, %arg1: tensor<3x2xi32>) -> (tensor<2x3xi32>, tensor<2x3xi32>) {
- %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<2x3xi32>) -> tensor<3x2xi32>
%2 = tosa.ceil %1 : (tensor<3x2xi32>) -> tensor<3x2xi32>
%3 = tosa.add %2, %arg1 : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
- %4 = tosa.transpose %2, %0 : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
- %5 = tosa.transpose %3, %0 : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %4 = tosa.transpose %2 {perms = array<i32: 1, 0>}: (tensor<3x2xi32>) -> tensor<2x3xi32>
+ %5 = tosa.transpose %3 {perms = array<i32: 1, 0>}: (tensor<3x2xi32>) -> tensor<2x3xi32>
return %4, %5 : tensor<2x3xi32>, tensor<2x3xi32>
}
@@ -388,12 +357,10 @@ func.func @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents(%arg0: t
// CHECK-NEXT: %[[RES:.*]] = tosa.clamp %arg0
// CHECK-NEXT: return %[[RES]], %[[RES]]
func.func @test_direct_use_in_other_transpose_with_same_perms(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
- %2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
}
@@ -405,8 +372,7 @@ func.func @test_direct_use_in_other_transpose_with_same_perms(%arg0: tensor<3x3x
// CHECK: return %[[NEW]]
func.func @test_const_transpose() -> tensor<2x3xi32> {
%0 = "tosa.const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
- %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<3x2xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
}
@@ -420,8 +386,7 @@ func.func @test_const_transpose() -> tensor<2x3xi32> {
func.func @test_transpose_tracks_to_const_single_step() -> tensor<1x2x3x4xi32> {
%0 = "tosa.const"() {value = dense<0> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %clamp, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %1 : tensor<1x2x3x4xi32>
}
@@ -434,13 +399,11 @@ func.func @test_transpose_tracks_to_const_single_step() -> tensor<1x2x3x4xi32> {
// CHECK: %[[NEW_NOT:.*]] = tosa.bitwise_not %[[NEW_ABS]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
// CHECK: return %[[NEW_NOT]]
func.func @test_static_unary_path_to_const() -> tensor<1x2x3x4xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%0 = "tosa.const"() {value = dense<1> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %bitwise_not, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %1 = tosa.transpose %bitwise_not {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %1 : tensor<1x2x3x4xi32>
}
@@ -455,16 +418,14 @@ func.func @test_static_unary_path_to_const() -> tensor<1x2x3x4xi32> {
// CHECK: %[[NEW_ADD:.*]] = tosa.add %[[NEW_ABS]], %[[NEW_CLAMP]] : (tensor<1x2x3x4xi32>, tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
// CHECK: return %[[NEW_ADD]]
func.func @test_static_diverges_to_non_splat_const_and_nullifying(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
%const = "tosa.const"() {value = dense<[[[[1, 2], [3, 4], [5, 6], [7, 8]],
[[9, 10], [11, 12], [13, 14], [15, 16]],
[[17, 18], [19, 20], [21, 22], [23, 24]]]]> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %transpose0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %const : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%add = tosa.add %abs, %clamp : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
- %perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %result = tosa.transpose %add, %perms2 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %result = tosa.transpose %add {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
@@ -474,12 +435,10 @@ func.func @test_static_diverges_to_non_splat_const_and_nullifying(%arg0: tensor<
// CHECK-NEXT: %[[RES:.*]] = tosa.clamp %arg0
// CHECK-NEXT: return %[[RES]], %[[RES]]
func.func @test_multi_downstream_both_nullify(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
- %2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
}
@@ -487,19 +446,15 @@ func.func @test_multi_downstream_both_nullify(%arg0: tensor<3x3x3x3xi32>) -> (te
// COM: we don't perform this transformation intentionally, since we would then get duplicates
// CHECK-LABEL: @test_multi_downstream_one_nullifies_upstream_other_does_not
-// CHECK: tosa.const
// CHECK: tosa.transpose
// CHECK: tosa.clamp
-// CHECK: tosa.const
// CHECK: tosa.transpose
// CHECK: tosa.transpose
func.func @test_multi_downstream_one_nullifies_upstream_other_does_not(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
%clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
- %2 = tosa.transpose %clamp, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp {perms = array<i32: 0, 2, 3, 1>}: (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
}
@@ -508,9 +463,8 @@ func.func @test_multi_downstream_one_nullifies_upstream_other_does_not(%arg0: te
// CHECK-LABEL: @test_unknown_dim_inner_replacement_matches
// CHECK-NEXT: return %arg0
func.func @test_unknown_dim_inner_replacement_matches(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
- %0 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<?x3xi32>
- %1 = tosa.transpose %0, %perms : (tensor<?x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<3x2xi32>) -> tensor<?x3xi32>
+ %1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<?x3xi32>) -> tensor<3x2xi32>
return %1 : tensor<3x2xi32>
}
@@ -520,9 +474,8 @@ func.func @test_unknown_dim_inner_replacement_matches(%arg0: tensor<3x2xi32>) ->
// CHECK-LABEL: @test_unknown_dim_outer_replacement_matches
// CHECK-NEXT: return %arg0
func.func @test_unknown_dim_outer_replacement_matches(%arg0: tensor<3x?xi32>) -> tensor<3x?xi32> {
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
- %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<2x3xi32>
- %1 = tosa.transpose %0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x?xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<3x?xi32>) -> tensor<2x3xi32>
+ %1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<2x3xi32>) -> tensor<3x?xi32>
return %1 : tensor<3x?xi32>
}
@@ -534,47 +487,28 @@ func.func @test_unknown_dim_outer_replacement_matches(%arg0: tensor<3x?xi32>) ->
// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
// CHECK-NEXT: return %[[ADD]]
func.func @test_transpose_tracks_to_nullifying_diverging_binary_unknown_dim_replacements_match(%arg0: tensor<1x?x3x4xi32>, %arg1: tensor<1x2x?x4xi32>) -> tensor<1x2x3x4xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x3x4x?xi32>
- %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x?x4xi32>, tensor<4xi32>) -> tensor<1x?x?x2xi32>
+ %transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x?x3x4xi32>) -> tensor<?x3x4x?xi32>
+ %transpose1 = tosa.transpose %arg1 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x?x4xi32>) -> tensor<1x?x?x2xi32>
%clamp = tosa.clamp %transpose0 {min_val = 0 : i32, max_val = 1 : i32} : (tensor<?x3x4x?xi32>) -> tensor<?x3x4x?xi32>
%abs = tosa.abs %transpose1 : (tensor<1x?x?x2xi32>) -> tensor<1x?x?x2xi32>
%add = tosa.add %clamp, %abs : (tensor<?x3x4x?xi32>, tensor<1x?x?x2xi32>) -> tensor<1x3x4x2xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %result = tosa.transpose %add {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
// -----
-// COM: we cannot do anything to the transpose in this case.
-// CHECK-LABEL: @test_unimplemented_non_const_perms
-// CHECK: tosa.const
-// CHECK-NEXT: tosa.transpose
-// CHECK-NEXT: return
-func.func @test_unimplemented_non_const_perms(%perms: tensor<2xi32>) -> tensor<?x?xi32> {
- %0 = "tosa.const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
- %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<?x?xi32>
- return %1 : tensor<?x?xi32>
-}
-
-// -----
-
// COM: due to tracking back to a non-nullifying transpose, we can't get rid of the transposes entirely.
// COM: later editions of the pass may wish to fold these into a single transpose.
// CHECK-LABEL: @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_single_step
-// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.clamp
-// CHECK-NEXT: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x4x3xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x4x3x2xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 0, 3, 2, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x4x3x2xi32>
%clamp = tosa.clamp %0 {min_val = 0 : i32, max_val = 1 : i32} : (tensor<1x4x3x2xi32>) -> tensor<1x4x3x2xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %clamp, %perms1 : (tensor<1x4x3x2xi32>, tensor<4xi32>) -> tensor<1x2x4x3xi32>
+ %1 = tosa.transpose %clamp {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x4x3x2xi32>) -> tensor<1x2x4x3xi32>
return %1 : tensor<1x2x4x3xi32>
}
@@ -582,28 +516,24 @@ func.func @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_singl
// COM: we don't deal with this case. resolution of shapes required.
// CHECK-LABEL: @test_unimplemented_unknown_dim_input_nullifying_pair
-// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_unknown_dim_input_nullifying_pair(%arg0: tensor<3x?xi32>) -> tensor<3x2xi32> {
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
- %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<2x3xi32>
- %1 = tosa.transpose %0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<3x?xi32>) -> tensor<2x3xi32>
+ %1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<2x3xi32>) -> tensor<3x2xi32>
return %1 : tensor<3x2xi32>
}
// -----
// CHECK-LABEL: @test_unimplemented_unknown_dim_replacement_does_not_match
-// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_unknown_dim_replacement_does_not_match(%arg0: tensor<3x?xi32>) -> tensor<?x?xi32> {
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
- %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<?x3xi32>
- %1 = tosa.transpose %0, %perms : (tensor<?x3xi32>, tensor<2xi32>) -> tensor<?x?xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<3x?xi32>) -> tensor<?x3xi32>
+ %1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<?x3xi32>) -> tensor<?x?xi32>
return %1 : tensor<?x?xi32>
}
@@ -611,53 +541,43 @@ func.func @test_unimplemented_unknown_dim_replacement_does_not_match(%arg0: tens
// COM: this would be able to be converted if --tosa-infer-shapes was run beforehand
// CHECK-LABEL: @test_unimplemented_unranked_tensors_present
-// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_unranked_tensors_present(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
- %perms = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
- %0 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
- %1 = tosa.transpose %0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 0, 1>}: (tensor<3x2xi32>) -> tensor<*xi32>
+ %1 = tosa.transpose %0 {perms = array<i32: 0, 1>}: (tensor<*xi32>) -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
// CHECK-LABEL: @test_unimplemented_unranked_everything
-// CHECK: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_unranked_everything(%arg0: tensor<*xi32>) -> tensor<*xi32> {
- %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
- %0 = tosa.transpose %arg0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
- %1 = tosa.transpose %0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 1, 0>}: (tensor<*xi32>) -> tensor<*xi32>
+ %1 = tosa.transpose %0 {perms = array<i32: 1, 0>}: (tensor<*xi32>) -> tensor<*xi32>
return %1 : tensor<*xi32>
}
// -----
// CHECK-LABEL: @test_unimplemented_static_diverges_to_one_nullifying_one_non_nullifying
-// CHECK: tosa.const
-// CHECK-NEXT: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: tosa.clamp
// CHECK-NEXT: tosa.abs
// CHECK-NEXT: tosa.add
-// CHECK-NEXT: tosa.const
// CHECK-NEXT: tosa.transpose
// CHECK-NEXT: return
func.func @test_unimplemented_static_diverges_to_one_nullifying_one_non_nullifying(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x4x3xi32>) -> tensor<1x2x3x4xi32> {
- %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %perms1 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
- %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
- %transpose1 = tosa.transpose %arg1, %perms1 : (tensor<1x2x4x3xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 3, 1>}: (tensor<1x2x3x4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1 {perms = array<i32: 0, 3, 2, 1>}: (tensor<1x2x4x3xi32>) -> tensor<1x3x4x2xi32>
%clamp = tosa.clamp %transpose0 {min_val = 0 : i32, max_val = 1 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
%add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
- %perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %result = tosa.transpose %add, %perms2 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ %result = tosa.transpose %add {perms = array<i32: 0, 3, 1, 2>}: (tensor<1x3x4x2xi32>) -> tensor<1x2x3x4xi32>
return %result : tensor<1x2x3x4xi32>
}
diff --git a/mlir/test/Dialect/Tosa/transpose-fold.mlir b/mlir/test/Dialect/Tosa/transpose-fold.mlir
index a7e9982c8db70..d29cb54eb2117 100644
--- a/mlir/test/Dialect/Tosa/transpose-fold.mlir
+++ b/mlir/test/Dialect/Tosa/transpose-fold.mlir
@@ -6,10 +6,8 @@
// CHECK: }
func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
- %0 = "tosa.const"() {value = dense<[1, 2, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
- %2 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
- %3 = tosa.transpose %1, %2 : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 1, 2, 0> }: (tensor<1x2x3xi32>) -> tensor<2x3x1xi32>
+ %3 = tosa.transpose %1 { perms = array<i32: 2, 0, 1> }: (tensor<2x3x1xi32>) -> tensor<1x2x3xi32>
return %3 : tensor<1x2x3xi32>
}
@@ -21,8 +19,7 @@ func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<
// CHECK: }
func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
- %0 = "tosa.const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 0, 1, 2> }: (tensor<1x2x3xi32>) -> tensor<1x2x3xi32>
return %1 : tensor<1x2x3xi32>
}
@@ -30,16 +27,13 @@ func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1
// CHECK-LABEL: func.func @test_do_not_cancel_different_transpose(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32> {
-// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[3, 2, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
+// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]] {perms = array<i32: 3, 2, 1, 0>} : (tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32>
// CHECK: return %[[VAL_2]] : tensor<5x4x3x2xi32>
// CHECK: }
func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) -> (tensor<5x4x3x2xi32>) {
- %0 = "tosa.const"() {value = dense<[1, 2, 0, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<3x4x2x5xi32>
- %2 = "tosa.const"() {value = dense<[3, 1, 0, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %3 = tosa.transpose %1, %2 : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 1, 2, 0, 3> }: (tensor<2x3x4x5xi32>) -> tensor<3x4x2x5xi32>
+ %3 = tosa.transpose %1 { perms = array<i32: 3, 1, 0, 2> }: (tensor<3x4x2x5xi32>) -> tensor<5x4x3x2xi32>
return %3 : tensor<5x4x3x2xi32>
}
@@ -47,15 +41,12 @@ func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) ->
// CHECK-LABEL: func.func @test_prefer_compose_transpose(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32> {
-// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[3, 2, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
+// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]] {perms = array<i32: 3, 2, 1, 0>} : (tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32>
// CHECK: return %[[VAL_2]] : tensor<4x3x2x1xi32>
// CHECK: }
func.func @test_prefer_compose_transpose(%arg0: tensor<1x2x3x4xi32>) -> (tensor<4x3x2x1xi32>) {
- %0 = "tosa.const"() {value = dense<[1, 2, 0, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<2x3x1x4xi32>
- %2 = "tosa.const"() {value = dense<[3, 1, 0, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
- %3 = tosa.transpose %1, %2 : (tensor<2x3x1x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
+ %1 = tosa.transpose %arg0 { perms = array<i32: 1, 2, 0, 3> }: (tensor<1x2x3x4xi32>) -> tensor<2x3x1x4xi32>
+ %3 = tosa.transpose %1 { perms = array<i32: 3, 1, 0, 2> }: (tensor<2x3x1x4xi32>) -> tensor<4x3x2x1xi32>
return %3 : tensor<4x3x2x1xi32>
}
More information about the Mlir-commits mailing list