[Mlir-commits] [mlir] [MLIR][TOSA] Add --tosa-remove-redundant-transposes pass (PR #108260)
Arteen Abrishami
llvmlistbot at llvm.org
Fri Sep 13 13:49:45 PDT 2024
https://github.com/arteen1000 updated https://github.com/llvm/llvm-project/pull/108260
>From 509076ed101a0b533647112c23cacda8c5f59282 Mon Sep 17 00:00:00 2001
From: Arteen Abrishami <arteen.abrishami at arm.com>
Date: Thu, 12 Sep 2024 20:18:29 +0000
Subject: [PATCH] [MLIR][TOSA] Add --tosa-reduce-transposes pass
----------
Motivation:
----------
Some legalization pathways introduce redundant tosa.TRANSPOSE
operations that result in avoidable data movement. For example,
PyTorch -> TOSA contains a lot of unnecessary transposes due
to conversions between NCHW and NHWC.
We wish to remove all the ones that we can, since in general
it is possible to remove the overwhelming majority.
------------
Changes Made:
------------
- Add the --tosa-reduce-transposes pass
- Add TosaElementwiseOperator trait.
-------------------
High-Level Overview:
-------------------
The pass works through the transpose operators in the program. It begins at some
transpose operator with an associated permutation tensor. It traverses upwards
through the dependencies of this transpose and verifies that we encounter only
operators with the TosaElementwiseOperator trait and that every path terminates
in a constant, reshape, or transpose.
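The upward traversal can be sketched on a toy graph. This is a minimal C++17 sketch, not the pass's code: `Kind`, `Node`, and this `collectFanIn` signature are hypothetical stand-ins for `mlir::Operation` and the trait check. A null operand models a function argument, where the traversal currently gives up.

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy IR (not the MLIR API). Elementwise stands in for ops with
// the TosaElementwiseOperator trait; Other is any op without it.
enum class Kind { Const, Reshape, Transpose, Elementwise, Other };

struct Node {
  Kind kind;
  std::vector<const Node *> operands; // nullptr models a function argument
};

// Collects the fan-in of `op` in operands-first (topological) order.
// Returns false if any path reaches a function argument or a non-elementwise
// op before ending at a Const, Reshape, or Transpose.
bool collectFanIn(const Node *op, std::vector<const Node *> &collected) {
  if (op == nullptr || op->kind == Kind::Other)
    return false;
  for (const Node *seen : collected)
    if (seen == op)
      return true; // already reached through another path
  const bool isRoot = op->kind == Kind::Const || op->kind == Kind::Reshape ||
                      op->kind == Kind::Transpose;
  // Roots end the traversal; elementwise ops are traversed through.
  if (!isRoot)
    for (const Node *operand : op->operands)
      if (!collectFanIn(operand, collected))
        return false;
  collected.push_back(op);
  return true;
}
```

For `add = Elementwise(const, transpose)` the collected order is const, transpose, add, which is also an order in which replacement ops can be built, since every op comes after its operands.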
We then evaluate whether there are any additional restrictions (the transposes
it terminates in must invert the one we began at, and the reshapes must be ones
into which we can fold the transpose), and then we hoist the transpose through
the intervening operators, folding it at the constants, reshapes, and transposes.
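Two of these folds reduce to small arithmetic. The first checks that an upstream transpose undoes the hoisted one. The second transposes a constant's row-major data by reordering its strides into the output's index order. Below is a minimal C++17 sketch on `std::vector` rather than `DenseElementsAttr`. The helper names (`permuteShape`, `areInverse`, `transposeFlat`) are hypothetical, and every dimension is assumed to be at least 1.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// TOSA transpose shape rule: output dim i takes input dim perms[i].
std::vector<int64_t> permuteShape(const std::vector<int64_t> &shape,
                                  const std::vector<int32_t> &perms) {
  assert(shape.size() == perms.size());
  std::vector<int64_t> out;
  for (int32_t p : perms)
    out.push_back(shape[p]);
  return out;
}

// True if transposing by `a` and then by `b` is the identity, i.e. `b` is
// the inverse permutation of `a`. Both must be valid permutations.
bool areInverse(const std::vector<int32_t> &a, const std::vector<int32_t> &b) {
  if (a.size() != b.size())
    return false;
  for (size_t i = 0; i < a.size(); ++i)
    if (b[a[i]] != static_cast<int32_t>(i))
      return false;
  return true;
}

// Transposes row-major `input` of shape `shape` by `perms`. The input strides
// are reordered into the output's index order, so walking the output in
// row-major order updates the matching input offset incrementally.
template <typename T>
std::vector<T> transposeFlat(const std::vector<T> &input,
                             const std::vector<int64_t> &shape,
                             const std::vector<int32_t> &perms) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  assert(rank > 0 && perms.size() == shape.size());
  std::vector<int64_t> strides(rank, 1); // row-major input strides
  for (int64_t i = rank - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  std::vector<int64_t> permStrides; // input strides in output index order
  for (int32_t p : perms)
    permStrides.push_back(strides[p]);
  const std::vector<int64_t> outShape = permuteShape(shape, perms);

  std::vector<T> result;
  result.reserve(input.size());
  std::vector<int64_t> idx(rank, 0); // current output index (odometer)
  int64_t offset = 0;                // matching flat input offset
  while (true) {
    result.push_back(input[offset]);
    int64_t d = rank - 1;
    for (; d >= 0; --d) { // advance the odometer, carrying leftwards
      if (++idx[d] < outShape[d]) {
        offset += permStrides[d];
        break;
      }
      offset -= (outShape[d] - 1) * permStrides[d];
      idx[d] = 0;
    }
    if (d < 0)
      return result;
  }
}
```

For `perms = [2, 0, 1]` on a 2x3x4 input, the reordered input strides are [1, 12, 4]. This matches the worked example in the comment inside `transposeDenseAttribute`.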
Finally, we ensure that we do not need both the transposed form (the form that
had the transpose hoisted through it) and the untransposed form (the original
form), by analyzing the uses of the dependent operators of each transpose we are
attempting to hoist and replace.
If both forms would be required, we do not replace the hoisted transpose, which
leaves the new chain dead. Otherwise, we replace it and the old chain
(untransposed form) becomes dead. Only one chain is ever live, so nothing is
duplicated.
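The replacement decision is a fixpoint. Every candidate transpose starts out as replaceable. A candidate is then dropped if any op in its fan-in cone (constants excepted) has a use that is neither a transpose with the same perms nor inside the cone of a still-replaceable candidate. Dropping one candidate can invalidate others, so the loop repeats until nothing changes. The sketch below is a minimal C++17 toy model with integer op ids. `Candidate`, `ToyGraph`, and `goodReplacements` are hypothetical and are not the pass's data structures.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical toy model of the replacement analysis; ops are integer ids.
struct Candidate {
  int transposeId;            // hoisted transpose we would like to replace
  std::set<int> dependentOps; // its fan-in cone
};

struct ToyGraph {
  std::vector<std::vector<int>> users; // users[op] = ops using op's result
  std::set<int> samePermTransposes;    // transposes with the hoisted perms
  std::set<int> otherTransposes;       // transposes with other perms
  std::set<int> constOps;              // constants may be duplicated
};

// Assumes every candidate is replaceable, then removes one invalid candidate
// per round until a round removes nothing.
std::set<int> goodReplacements(const ToyGraph &g,
                               const std::vector<Candidate> &cands) {
  std::set<int> able;
  for (const Candidate &c : cands)
    able.insert(c.transposeId);

  // A use is covered if it lies in the cone of a still-replaceable candidate.
  auto inValidCone = [&](int user) {
    for (const Candidate &c : cands)
      if (able.count(c.transposeId) && c.dependentOps.count(user))
        return true;
    return false;
  };
  auto usesAreCovered = [&](const Candidate &c) {
    for (int op : c.dependentOps) {
      if (g.constOps.count(op))
        continue;
      for (int user : g.users[op]) {
        if (g.otherTransposes.count(user))
          return false;
        if (!g.samePermTransposes.count(user) && !inValidCone(user))
          return false;
      }
    }
    return true;
  };

  bool changed = true;
  while (changed) {
    changed = false;
    for (const Candidate &c : cands) {
      if (able.count(c.transposeId) && !usesAreCovered(c)) {
        able.erase(c.transposeId); // may invalidate candidates already checked
        changed = true;
        break;
      }
    }
  }
  return able;
}
```

If a shared op's only outside use sits in another candidate's cone, both candidates survive. If that other candidate also has a use outside every cone, it is dropped first, and the shared use then invalidates the first candidate on the next round.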
We then perform a simple one-pass DCE, so no canonicalization is necessary.
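Within a single block, one reverse sweep can remove a dead chain because every use appears after its definition. When an op is erased, its operands lose a use before they are visited. The sketch below is a minimal C++17 example over a hypothetical `ToyOp` list. It is not MLIR's `Operation` and not the pass's actual DCE code. It only shows how the old, untransposed chain disappears once nothing uses it.

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy op: operands refer to earlier ops in the block by index.
struct ToyOp {
  std::vector<int> operands;
  bool hasSideEffects = false; // e.g. a terminator such as func.return
};

// One reverse sweep: an op with no remaining uses and no side effects is
// erased, and each of its operands loses one use. Returns which ops survive.
std::vector<bool> oneSweepDCE(const std::vector<ToyOp> &block) {
  std::vector<int> useCount(block.size(), 0);
  for (const ToyOp &op : block)
    for (int operand : op.operands)
      ++useCount[operand];

  std::vector<bool> alive(block.size(), true);
  for (int i = static_cast<int>(block.size()) - 1; i >= 0; --i) {
    if (block[i].hasSideEffects || useCount[i] > 0)
      continue;
    alive[i] = false;
    for (int operand : block[i].operands)
      --useCount[operand];
  }
  return alive;
}
```

For example, take an old `const -> add` chain and a new `const -> add -> return` chain. The sweep keeps the new chain and removes both ops of the old one in a single pass.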
--------------
Impact of Pass:
--------------
Patching the dense_resource artifacts (from PyTorch) with dense attributes to
permit constant folding, we obtain the following results.
Note that data movement represents total transpose data movement, calculated
by noting which dimensions moved during the transpose.
///////////
MobilenetV3:
///////////
BEFORE total data movement: 11798776 B (11.25 MiB)
AFTER total data movement: 2998016 B (2.86 MiB)
74.6% of data movement removed.
BEFORE transposes: 82
AFTER transposes: 20
75.6% of transposes removed.
////////
ResNet18:
////////
BEFORE total data movement: 20596556 B (19.64 MiB)
AFTER total data movement: 1003520 B (0.96 MiB)
95.1% of data movement removed.
BEFORE transposes: 56
AFTER transposes: 5
91.1% of transposes removed.
////////
ResNet50:
////////
BEFORE total data movement: 83236172 B (79.38 MiB)
AFTER total data movement: 3010560 B (2.87 MiB)
96.4% of data movement removed.
BEFORE transposes: 120
AFTER transposes: 7
94.2% of transposes removed.
/////////
ResNet101:
/////////
BEFORE total data movement: 124336460 B (118.58 MiB)
AFTER total data movement: 3010560 B (2.87 MiB)
97.6% of data movement removed.
BEFORE transposes: 239
AFTER transposes: 7
97.1% of transposes removed.
/////////
ResNet152:
/////////
BEFORE total data movement: 175052108 B (166.94 MiB)
AFTER total data movement: 3010560 B (2.87 MiB)
98.3% of data movement removed.
BEFORE transposes: 358
AFTER transposes: 7
98.0% of transposes removed.
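The "removed" percentages and MiB figures follow from the byte and transpose counts by plain arithmetic. The short C++17 sketch below reproduces them, rounding to one decimal place as in the figures above. `percentRemoved` and `toMiB` are hypothetical helper names, and MiB is taken as 2^20 bytes.

```cpp
#include <cassert>
#include <cmath>

// Percentage of a quantity removed, given before/after totals.
double percentRemoved(double before, double after) {
  assert(before > 0);
  return 100.0 * (before - after) / before;
}

// Bytes to MiB (2^20 bytes).
double toMiB(double bytes) { return bytes / (1024.0 * 1024.0); }
```

For example, `percentRemoved(11798776, 2998016)` rounds to 74.6 and `percentRemoved(82, 20)` rounds to 75.6, matching the MobilenetV3 figures.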
////////
Overview:
////////
We see that we remove up to 98% of transposes and eliminate
up to 98.3% of transpose data movement.
In the context of ResNet50, at 120 inferences per second,
we reduce dynamic transpose data bandwidth from 9.30 GiB/s
to 344.5 MiB/s.
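The bandwidth figures are the per-inference transpose bytes multiplied by the inference rate. The C++17 sketch below makes the arithmetic explicit. `bytesPerSecond`, `kMiB`, and `kGiB` are hypothetical names, and binary units (2^20 and 2^30 bytes) are assumed.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

constexpr double kMiB = 1024.0 * 1024.0;
constexpr double kGiB = 1024.0 * kMiB;

// Sustained transpose traffic: bytes moved per inference times the rate.
double bytesPerSecond(uint64_t bytesPerInference, double inferencesPerSecond) {
  return static_cast<double>(bytesPerInference) * inferencesPerSecond;
}
```

Dividing `bytesPerSecond(83236172, 120)` by `kGiB`, or `bytesPerSecond(3010560, 120)` by `kMiB`, gives the ResNet50 before/after rates. Starting from the exact byte totals rather than the rounded MiB values keeps the last digit from drifting.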
-----------
Future Work:
-----------
(1) Evaluate tradeoffs with permitting ConstOp to be duplicated across hoisted
transposes with different permutation tensors.
(2) Expand the class of foldable upstream ReshapeOp we permit beyond
N -> 1x1x...x1xNx1x...x1x1.
(3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
those that form the identity.
(4) Add support for more instructions besides TosaElementwiseOperator as
the intervening ones (for example, the reduce_* operators).
(5) Support hoisting transposes up to an input parameter.
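The reshape restriction in item (2) can be illustrated in C++17. Here a reshape from a rank-1 tensor of size N qualifies only if it produces N surrounded by unit dimensions. Because at most one dimension is not 1, transposing the reshaped tensor does not reorder any elements. The transpose can therefore be folded by permuting the reshape's target shape instead. `isFoldableReshape` and `foldTransposeIntoReshape` are hypothetical names for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// True if reshaping a rank-1 tensor of size n to `outShape` has the form
// N -> 1x...x1xNx1x...x1 (or all ones when n == 1).
bool isFoldableReshape(int64_t n, const std::vector<int64_t> &outShape) {
  if (outShape.empty())
    return false;
  int64_t product = 1;
  size_t ones = 0;
  for (int64_t d : outShape) {
    product *= d;
    if (d == 1)
      ++ones;
  }
  assert(product == n && "not a valid reshape");
  return n == 1 ? ones == outShape.size() : ones == outShape.size() - 1;
}

// Folding: reshape(x, permute(shape)) holds the same data as
// transpose(reshape(x, shape), perms) when only one dimension is not 1.
std::vector<int64_t>
foldTransposeIntoReshape(const std::vector<int64_t> &outShape,
                         const std::vector<int32_t> &perms) {
  assert(outShape.size() == perms.size());
  std::vector<int64_t> folded;
  for (int32_t p : perms)
    folded.push_back(outShape[p]);
  return folded;
}
```

For example, a bias reshaped to `1x64x1x1` (NCHW) and hoisted through an NCHW-to-NHWC transpose with perms `[0, 2, 3, 1]` becomes a direct reshape to `1x1x1x64`.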
Signed-off-by: Arteen Abrishami <arteen.abrishami at arm.com>
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 10 +
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 6 +
.../mlir/Dialect/Tosa/Transforms/Passes.td | 21 +
.../Dialect/Tosa/Transforms/CMakeLists.txt | 1 +
.../Tosa/Transforms/TosaReduceTransposes.cpp | 693 ++++++++++++++++++
.../Dialect/Tosa/tosa-reduce-transposes.mlir | 649 ++++++++++++++++
6 files changed, 1380 insertions(+)
create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
create mode 100644 mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..64bacd0e432fe5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -206,6 +206,15 @@ def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
input, paddings, pad_value);
}]>;
+//===----------------------------------------------------------------------===//
+// TOSA Operator Trait.
+//===----------------------------------------------------------------------===//
+
+// Permits broadcasting. Elementwise trait is too strict.
+def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
+ let cppNamespace = "mlir::OpTrait::tosa";
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Class.
//===----------------------------------------------------------------------===//
@@ -219,6 +228,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape,
+ TosaElementwiseOperator,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 7ed89bff474a2e..66512cbe350ec8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -84,6 +84,12 @@ class MulOperandsAndResultElementType
}
};
+/// This class indicates that an op is tosa-elementwise (permits broadcasting,
+/// unlike Elementwise trait).
+template <typename ConcreteType>
+class TosaElementwiseOperator
+ : public TraitBase<ConcreteType, TosaElementwiseOperator> {};
+
} // namespace tosa
} // namespace OpTrait
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index a0f670de20150f..c0352fa88fe08d 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -126,4 +126,25 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
];
}
+def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
+ let summary = "Reduce transposes through other operators";
+ let description = [{
+ Pass that identifies and reduces tosa.TRANSPOSE operations through chains
+ of operators.
+
+ The pass traverses dependencies of tosa.TRANSPOSE operations until they
+ terminate in either a tosa.RESHAPE that we can fold the hoisted
+ tosa.TRANSPOSE into, a tosa.TRANSPOSE that forms the identity with the
+ hoisted one, or a tosa.CONST with a dense elements attribute. It then
+ propagates the hoisted transform upward through the intervening operators
+ where support is implemented. Finally, it checks that replacing the hoisted
+ tosa.TRANSPOSE will not leave both the original chain and the new chain
+ live, and only then replaces the hoisted tosa.TRANSPOSE.
+
+ The pass has an important use-case in cleaning up the results of frameworks
+ that introduce a lot of data-layout transformations when legalizing to TOSA,
+ a common one being transformations between NHWC and NCHW layouts.
+ }];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index c78a74b874aff1..5b0f5ec4cd5687 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaLayerwiseConstantFoldPass.cpp
TosaMakeBroadcastable.cpp
TosaOptionalDecompositions.cpp
+ TosaReduceTransposes.cpp
TosaTypeConverters.cpp
TosaValidation.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
new file mode 100644
index 00000000000000..6911ecd50de4ea
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -0,0 +1,693 @@
+//===- TosaReduceTransposes.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// ----------
+// Motivation:
+// ----------
+
+// Some legalization pathways introduce redundant tosa.TRANSPOSE
+// operations that result in avoidable data movement. For example,
+// PyTorch -> TOSA contains a lot of unnecessary transposes due
+// to conversions between NCHW and NHWC.
+
+// We wish to remove all the ones that we can, since in general
+// it is possible to remove the overwhelming majority.
+
+// -------------------
+// High-Level Overview:
+// -------------------
+
+// The pass works through the transpose operators in the program. It begins at
+// some transpose operator with an associated permutation tensor. It traverses
+// upwards through the dependencies of this transpose and verifies that we
+// encounter only operators with the TosaElementwiseOperator trait and that
+// every path terminates in a constant, reshape, or transpose.
+
+// We then evaluate whether there are any additional restrictions (the
+// transposes it terminates in must invert the one we began at, and the reshapes
+// must be ones into which we can fold the transpose), and then we hoist the
+// transpose through the intervening operators, folding it at the constants,
+// reshapes, and transposes.
+
+// Finally, we ensure that we do not need both the transposed form (the form
+// that had the transpose hoisted through it) and the untransposed form (the
+// original form), by analyzing the uses of the dependent operators of each
+// transpose we are attempting to hoist and replace.
+
+// If both forms would be required, we do not replace the hoisted transpose,
+// which leaves the new chain dead. Otherwise, we replace it and the old chain
+// (untransposed form) becomes dead. Only one chain is ever live, so nothing
+// is duplicated.
+
+// We then perform a simple one-pass DCE, so no canonicalization is necessary.
+
+// -----------
+// Future Work:
+// -----------
+
+// (1) Evaluate the tradeoffs of permitting ConstOp to be
+// duplicated across hoisted transposes with different
+// permutation tensors.
+
+// (2) Expand the class of foldable upstream ReshapeOp we permit beyond
+// N -> 1x1x...x1xNx1x...x1x1.
+
+// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
+// those that form the identity.
+
+// (4) Add support for more instructions besides TosaElementwiseOperator as
+// the intervening ones (for example, the reduce_* operators).
+
+// (5) Support hoisting transposes up to an input parameter.
+
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <memory>
+#include <set>
+#include <stack>
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+//===----------------------------------------------------------------------===//
+// TOSA Reduce Transposes Pass.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct TosaReduceTransposes final
+ : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
+ void runOnOperation() override;
+
+private:
+ // This will collect all the data dependencies for the given Operation
+ // up to and including ConstOp, ReshapeOp, and TransposeOp.
+ bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
+ bool convertDependentOps(SetVector<Operation *> &dependentOps,
+ DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter,
+ ArrayRef<int32_t> hoistedPerms);
+
+ // Checks if the two permutations, when applied consecutively, result
+ // in the identity.
+ bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
+ ArrayRef<int32_t> perms2);
+
+ // This is meant to apply to operations with the TosaElementwiseOperator
+ // trait.
+ std::optional<Value>
+ buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
+
+ // Handles another TransposeOp found as a dependency of the hoisted one:
+ // given %0 = tosa.transpose %arg0 and %1 = tosa.transpose %0, it applies
+ // to %0 when tracking back from %1; its result is recorded in valuesMap.
+ std::optional<Value>
+ buildMappedToValue(TransposeOp transposeOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
+
+ // Checks if ReshapeOp can have hoisted TransposeOp folded into it. If so,
+ // it creates a new ReshapeOp with that fold.
+ std::optional<Value>
+ buildMappedToValue(ReshapeOp reshapeOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
+
+ // We may have something like:
+ // %0 = tosa.const
+ // %1 = tosa.transpose
+ // %2 = tosa.add %0, %1
+ // %3 = tosa.transpose %2
+ // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
+ // in MobilenetV3.
+ std::optional<Value>
+ buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
+
+ // Checks which TransposeOp we should "replace", turning their converted
+ // chains of ops, through which they were propagated, "live", and the old code
+ // "dead." Attempts to avoid doing so when doing so would result in the old
+ // code staying "live," resulting in duplication.
+ std::set<TransposeOp> getGoodReplacements(
+ ArrayRef<int32_t> perms,
+ std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Helper function for dependenciesAreValid.
+ bool userNotContainedInValidTransposeDependencies(
+ Operation *user, std::set<TransposeOp> &validTransposes,
+ std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Helper function for getGoodReplacements to check if some TransposeOp's
+ // dependencies are OK.
+ bool dependenciesAreValid(
+ ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+ std::set<TransposeOp> &validTransposes,
+ std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Applies perms to the DenseElementsAttr.
+ // If it returns std::nullopt, it also triggers pass failure, since verifier
+ // guarantees from TOSA are not in place (and otherwise, if used elsewhere,
+ // it should fail).
+ // This is a basic API and may benefit from refactor into the core MLIR APIs.
+ std::optional<DenseElementsAttr>
+ transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
+};
+
+std::optional<DenseElementsAttr>
+TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
+ ArrayRef<int32_t> perms) {
+ RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
+ RankedTensorType newType =
+ RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
+ oldType.getElementType());
+ size_t rank = oldType.getRank();
+
+ if (input.isSplat())
+ return input.reshape(newType);
+ // Guaranteed by the TransposeOp verifier and by TOSA disallowing tensors
+ // with a dimension of 0. If these do not hold, something is very wrong.
+ if (rank <= 0 || oldType.getNumElements() <= 0 || perms.size() != rank) {
+ signalPassFailure();
+ return std::nullopt;
+ }
+
+ // The algorithm is approximately as follows:
+ // input: perms, input flat array, input tensor type
+ // (1/2) Determine the row-major strides of the input/output.
+ // (3) Reorder the input strides to match the output's index order.
+ // (4) Write the output dimension by dimension. Example: perms [2, 0, 1],
+ // input 2x3x4, output 4x2x3. For i < 4, j < 2, k < 3:
+ // output[i][j][k] = input[j][k][i], i.e.
+ // output[6i + 3j + k] = input[12j + 4k + i]. Reordering the input strides
+ // to [1, 12, 4] gives input[1i + 12j + 4k], so the output can be written
+ // layer by layer in row-major order.
+
+ // Step 1/2: Strides for input. Output strides are not needed, since the
+ // output is written in row-major order and we can just push_back.
+
+ SmallVector<int64_t> originalInputStrides(rank);
+ originalInputStrides[rank - 1] = 1;
+ // Index with signed int64_t so the loop can terminate below 0.
+ for (int64_t i = rank - 2; i >= 0; i--)
+ originalInputStrides[i] =
+ originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
+
+ // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
+ // output which is done in row-major order.
+
+ SmallVector<int64_t> newInputStrides;
+ newInputStrides.reserve(rank);
+ for (int32_t v : perms)
+ newInputStrides.push_back(originalInputStrides[v]);
+
+ // Step 4: Write out the transposed "flat array" dimension by dimension.
+
+ auto inputArray = input.getValues<Attribute>();
+ SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
+ for (size_t i = 0; i < rank; i++)
+ boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
+
+ SmallVector<Attribute> resultArray;
+ resultArray.reserve(inputArray.size());
+
+ std::function<void(int64_t,
+ SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
+ processTransposeDim = [&](auto accumulatedIndex, auto it) {
+ if (it == boundsAndStrides.end()) {
+ resultArray.push_back(inputArray[accumulatedIndex]);
+ return;
+ }
+
+ for (int64_t i = 0; i < it->first; i++) {
+ int64_t j = accumulatedIndex + i * it->second;
+ processTransposeDim(j, it + 1);
+ }
+ };
+
+ processTransposeDim(0, boundsAndStrides.begin());
+
+ return DenseElementsAttr::get(newType, resultArray);
+}
+
+// The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
+// as the sources of the data dependencies, and TosaElementwiseOperator
+// after that, if the function returns true.
+bool TosaReduceTransposes::collectFanIn(Operation *op,
+ SetVector<Operation *> &collected) {
+ // Can occur if defined through the parameter to a func.func.
+ if (!op)
+ return false;
+
+ if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
+ return false;
+
+ // Prevent extra work if already seen.
+ if (collected.contains(op))
+ return true;
+
+ // Reject it here so later steps don't have to deal with it.
+ if (op->getNumResults() != 1 ||
+ !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
+ return false;
+
+ // We don't wish to traverse up a ReshapeOp, since generally we can't
+ // propagate a TransposeOp through it. TransposeOp, ReshapeOp, ConstOp
+ // will have no in-edges in the data dependency graph we construct for
+ // the downstream TransposeOp.
+ if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
+ !llvm::isa<tosa::ConstOp>(op)) {
+
+ if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+ return false;
+
+ for (Value operand : op->getOperands())
+ // If this is a problem in future, think about alternatives to recursion.
+ if (!collectFanIn(operand.getDefiningOp(), collected))
+ return false;
+ }
+
+ // Insert in topological order.
+ collected.insert(op);
+
+ return true;
+}
+
+// Assumes that, due to TransposeOp verification, perms arrays are
+// permutations of 0 .. perms.size() - 1.
+bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
+ ArrayRef<int32_t> perms2) {
+ if (perms1.size() != perms2.size())
+ return false;
+ int32_t N = perms1.size();
+ for (int32_t i = 0; i < N; i++)
+ if (perms2[perms1[i]] != i)
+ return false;
+ return true;
+}
+
+// Primary overload for those with TosaElementwiseOperator trait.
+// The other ones handle the case of the operations that occur at the
+// roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
+std::optional<Value> TosaReduceTransposes::buildMappedToValue(
+ Operation *op, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
+ if (op->getNumResults() != 1 ||
+ !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+ return std::nullopt;
+
+ auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
+ SmallVector<Value, 3> operands;
+ for (Value v : op->getOperands()) {
+ if (valuesMap.contains(v)) {
+ operands.push_back(valuesMap.at(v));
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ // Conceptually, we propagate the hoisted TransposeOp through
+ // these intervening operations. For example,
+
+ // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
+ // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) ->
+ // tensor<3x2xi32>
+
+ // becomes:
+ // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) ->
+ // tensor<3x2xi32>
+ // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>
+
+ // We construct this new tosa.clamp here, but it doesn't
+ // turn "live" until the transpose being hoisted through this chain
+ // is replaced with the proper value from the new chain.
+
+ return rewriter
+ .create(op->getLoc(), op->getName().getIdentifier(), operands,
+ RankedTensorType::get(
+ applyTOSAPermutation(resultType.getShape(), hoistedPerms),
+ resultType.getElementType()),
+ op->getAttrs())
+ ->getResult(0);
+}
+
+std::optional<Value> TosaReduceTransposes::buildMappedToValue(
+ TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
+ SmallVector<int32_t> perms;
+ if (failed(transposeOp.getConstantPerms(perms)) ||
+ !areInvolutionTransposes(hoistedPerms, perms))
+ return std::nullopt;
+ return transposeOp.getInput1();
+}
+
+std::optional<Value> TosaReduceTransposes::buildMappedToValue(
+ ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
+ auto reshapeOutput = reshapeOp.getOutput();
+ auto reshapeInputType =
+ llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
+ // Want reshape N -> 1x1x...x1xNx1x...x1x1; check the type before using it.
+ if (!reshapeInputType || reshapeInputType.getRank() != 1)
+ return std::nullopt;
+ auto reshapeInputShape = reshapeInputType.getShape();
+ auto reshapeOutputType =
+ llvm::cast<RankedTensorType>(reshapeOutput.getType());
+
+ // Instead of inserting a TransposeOp here, we check if we can fold it into
+ // the ReshapeOp. There are more complex cases where this is possible, and
+ // this check can be extended.
+
+ // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1
+ auto shape = reshapeOutputType.getShape();
+ size_t ones = llvm::count(shape, 1);
+ // Accept N != 1 (all but one dim are 1) or N == 1 (all dims are 1).
+ if (ones != shape.size() - 1 &&
+ !(ones == shape.size() && reshapeInputShape[0] == 1))
+ return std::nullopt;
+
+ // Do not insert a TransposeOp, instead we fold the reshape and its attribute.
+ auto foldedReshape = rewriter.create<ReshapeOp>(
+ reshapeOp.getLoc(),
+ RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
+ reshapeOutputType.getElementType()),
+ reshapeOp.getInput1(),
+ rewriter.getDenseI64ArrayAttr(
+ applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
+ return foldedReshape->getResult(0);
+}
+
+std::optional<Value> TosaReduceTransposes::buildMappedToValue(
+ ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
+ auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!denseAttr)
+ return std::nullopt;
+ auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
+ if (!maybeNewDenseAttr.has_value())
+ return std::nullopt;
+ auto newDenseAttr = maybeNewDenseAttr.value();
+ auto newConstOp = rewriter.create<ConstOp>(
+ constOp.getLoc(), newDenseAttr.getType(), newDenseAttr);
+ return newConstOp->getResult(0);
+}
+
+bool TosaReduceTransposes::convertDependentOps(
+ SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
+
+ for (Operation *op : dependentOps) {
+ if (!op || op->getNumResults() != 1)
+ return false;
+
+ Value priorValue = op->getResult(0);
+
+ // It's possible on a prior transposeOp we had the same dependency and
+ // already resolved it.
+ if (valuesMap.contains(priorValue))
+ continue;
+
+ // Keep converted ops close to the original.
+ rewriter.setInsertionPointAfter(op);
+
+ std::optional<Value> maybeValue =
+ llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
+ .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto rootOp) {
+ return buildMappedToValue(rootOp, valuesMap, rewriter,
+ hoistedPerms);
+ })
+ .Default([&](Operation *op) {
+ return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
+ });
+
+ if (!maybeValue.has_value())
+ return false;
+
+ valuesMap[priorValue] = maybeValue.value();
+ }
+
+ return true;
+}
+
+bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
+ Operation *user, std::set<TransposeOp> &validTransposes,
+ std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
+ &transposeInfo) {
+ return llvm::none_of(
+ transposeInfo,
+ [&validTransposes,
+ user](const std::pair<TransposeOp, SetVector<Operation *>> &info) {
+ const auto &[transposeOp, dependentOps] = info;
+ return validTransposes.count(transposeOp) &&
+ dependentOps.contains(user);
+ });
+}
+
+// Dependencies are valid for an operation if none of them occur outside
+// of the proper fan-in cones of the hoisted TransposeOp with the same perms
+// that we can replace. Described in more detail within.
+bool TosaReduceTransposes::dependenciesAreValid(
+ ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+ std::set<TransposeOp> &validTransposes,
+ std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
+ &transposeInfo) {
+ for (Operation *op : dependentOps) {
+
+ // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
+ // This can be changed later if we find the memory impact is too high.
+ if (llvm::isa<ConstOp>(op))
+ continue;
+
+ for (OpOperand &use : op->getUses()) {
+ // Want the uses to be (1) contained in the dependentOps of other
+ // validTransposes, or (2) to be directly used in a TransposeOp with the
+ // same perms. For (2) it means the fan-in is a subset of our
+ // dependentOps, so it is also a validTranspose that will eventually be
+ // replaced.
+ Operation *user = use.getOwner();
+ if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
+ SmallVector<int32_t> otherPerms;
+
+ // Can later think about cases where transpose -> transpose
+ // or reshape -> transpose, where the transposes are not necessarily
+ // the same perms as the hoisted, if implementing a more general
+ // transform. These could be permitted.
+ if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
+ !llvm::equal(perms, otherPerms))
+ return false;
+ } else if (userNotContainedInValidTransposeDependencies(
+ user, validTransposes, transposeInfo)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+// Getting the set of TransposeOp that we can replace without causing
+// the old fan-in cones of any TransposeOp to remain "live", i.e., not being
+// dead code. This is done by iterating the set until convergence, since
+// if you are used outside your own fan-in cone, it's possible to be used
+// in another fan-in cone of a TransposeOp that is being replaced -- unless
+// we find that that one has a usage outside of it too.
+std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
+ ArrayRef<int32_t> perms,
+ std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
+ &transposeInfo) {
+ // Initially, we assume they are all good to replace,
+ // and we whittle them down based on our criteria.
+ std::set<TransposeOp> ableToReplace;
+ for (const auto &[transposeOp, _] : transposeInfo)
+ ableToReplace.insert(transposeOp);
+
+ bool gotRid;
+ do {
+ gotRid = false;
+ for (const auto &[transposeOp, dependentOps] : transposeInfo) {
+ // We don't care about it. Already invalidated.
+ if (!ableToReplace.count(transposeOp))
+ continue;
+
+ // Check for validity.
+ if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
+ transposeInfo)) {
+ ableToReplace.erase(transposeOp);
+ gotRid = true;
+ break;
+ }
+ }
+
+ } while (gotRid);
+
+ return ableToReplace;
+}
+
+void TosaReduceTransposes::runOnOperation() {
+ // We want to operate only within a single block.
+ if (!getOperation().getRegion().hasOneBlock())
+ return;
+
+ IRRewriter rewriter(&getContext());
+ // For each perms, maintain a mapping for converted ops to avoid duplication.
+ DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
+ // For each perms, we keep track of which TransposeOp are eligible
+ // for replacement alongside their dependentOps.
+ DenseMap<ArrayRef<int32_t>,
+ std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
+ permsToTransposeInfo;
+
+ // Necessary for lifetime, since DenseMap keeps a copy of the ArrayRef.
+ // Use SmallVector for perms (common-case is <= 4) but std::vector otherwise
+ // since no guarantee of smallness.
+ std::vector<SmallVector<int32_t>> collectedPerms;
+
+ // This keeps track of the order across all eligible-for-replacement
+ // TransposeOp and their perms, a necessity for the final replacements.
+ std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;
+
+ // We want to reserve the space up front, since SmallVector stores some data
+ // internally and the ArrayRef can reference that, which we don't want to get
+ // invalidated.
+ size_t expectedMaxPerms = 0;
+ getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
+ collectedPerms.reserve(expectedMaxPerms);
+
+ getOperation().walk([&](TransposeOp transposeOp) {
+ SetVector<Operation *> dependentOps;
+ collectedPerms.emplace_back();
+ SmallVector<int32_t> &perms = collectedPerms.back();
+
+ // Dynamic shapes are OK, but the incompatible ones will be rejected later.
+ auto input = transposeOp.getInput1();
+ auto output = transposeOp.getOutput();
+
+ // However, we don't support unranked tensors.
+ if (!llvm::isa<RankedTensorType>(input.getType()) ||
+ !llvm::isa<RankedTensorType>(output.getType()))
+ return;
+
+ // No transformation when transpose permutation non-constant.
+ if (failed(transposeOp.getConstantPerms(perms)))
+ return;
+
+ // We let --canonicalize deal with identity transposes.
+ if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
+ return;
+
+ // Fails if the basic invariants required to perform our conversions
+ // are not met.
+ if (!collectFanIn(input.getDefiningOp(), dependentOps))
+ return;
+
+ // Share valuesMap across already-converted ops with the same perms; keying
+ // by perms is needed because hoisted transposes with different perms may
+ // converge on an op, and each would transform it differently.
+ DenseMap<Value, Value> &valuesMap = permsToValues[perms];
+
+ // Attempt to perform the conversions and placements into IR
+ // without turning inserted code "live". Also fills out valuesMap.
+ // Fails if there is an intermediary we do not support.
+ if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
+ // Some additional operations may have been inserted, but will be
+ // removed by dead code elimination.
+ return;
+
+ // This should not happen. If it does, the state is unexpected,
+ // so we fail the pass.
+ if (!valuesMap.contains(input))
+ return signalPassFailure();
+
+ // The types may be incompatible (because of dynamic shapes); in these cases,
+ // the dynamic shapes should be resolved before running the
+ // pass.
+ if (output.getType() != valuesMap.at(input).getType())
+ return;
+
+ auto &transposeInfo = permsToTransposeInfo[perms];
+
+ // In general, we might also want to introduce "newDependentOps"
+ // if there are new usages that don't fall inside the original fan-ins
+ // (like the TransposeOp we insert for ReshapeOp),
+ // but that case is specialized enough, and it overlaps
+ // with another direct-use TransposeOp case we need to cover anyway.
+ transposeInfo.push_back({transposeOp, dependentOps});
+
+ // This is for the final replacement across all transposes.
+ totalTransposeOrder.push({transposeOp, perms});
+ });
+
+ // Do the full fan-in analysis per perms: if we did it across multiple perms
+ // and their fan-ins overlapped (e.g., through a shared dependency on a
+ // Reshape), we would also get duplicate ops.
+ // Const is special-cased.
+ std::set<TransposeOp> ableToReplace;
+ for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
+ // Returns replacements that would never result in us inserting any
+ // duplicate operations into the IR (i.e., our goal is only to remove
+ // transposes by replacing the existing chains, not to create a
+ // "new chain" to do so).
+ // Ideally, --canonicalize is run before this pass, since removing dead
+ // code helps this analysis by allowing more potentially acceptable
+ // transformations.
+ auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
+ ableToReplace.insert(goodReplacementsForPerms.begin(),
+ goodReplacementsForPerms.end());
+ }
+
+ // Perform the replacements across all transposes in reverse order,
+ // since any other order could invalidate the valuesMap
+ // mappings.
+ while (!totalTransposeOrder.empty()) {
+ auto [transposeOp, perms] = totalTransposeOrder.top();
+ totalTransposeOrder.pop();
+
+ if (ableToReplace.count(transposeOp) == 0)
+ continue;
+
+ auto &valuesMap = permsToValues[perms];
+ auto input = transposeOp.getInput1();
+
+ // The purpose of this reverse iteration
+ // is to avoid valuesMap invalidation. If invalidation
+ // happens anyway, something is wrong.
+ if (!valuesMap.contains(input))
+ return signalPassFailure();
+
+ rewriter.replaceOp(transposeOp, valuesMap.at(input));
+ }
+
+ // Remove all dead code by walking in reverse: users are visited (and
+ // erased if dead) before the ops defining their operands, so those
+ // producers are already trivially dead by the time we reach them.
+ getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
+ [&](Operation *op) {
+ if (isOpTriviallyDead(op))
+ rewriter.eraseOp(op);
+ });
+}
+
+} // namespace
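The permutation arithmetic the pass relies on (skipping identity perms, and requiring that the terminating transposes invert the starting one) can be illustrated outside MLIR. The following is a minimal, standalone C++ sketch, not part of this patch: `Perms`, `Shape`, `isIdentity`, `applyPerms`, `composePerms`, `invertPerms`, and `nullifies` are hypothetical names. It assumes tosa.transpose semantics where result dimension `i` is input dimension `perms[i]`, which matches the shapes in the tests in this patch (for example, `[0, 2, 3, 1]` maps `1x2x3x4` to `1x3x4x2`).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helpers (not the pass's API): a permutation is modeled as a
// std::vector<int32_t>, mirroring the constant perms tensor of tosa.transpose.
using Perms = std::vector<int32_t>;
using Shape = std::vector<int64_t>;

// True if perms == [0, 1, ..., n-1], i.e. the transpose is a no-op.
bool isIdentity(const Perms &perms) {
  for (std::size_t i = 0; i < perms.size(); ++i)
    if (perms[i] != static_cast<int32_t>(i))
      return false;
  return true;
}

// Result shape of a transpose: output dim i is input dim perms[i].
Shape applyPerms(const Shape &shape, const Perms &perms) {
  Shape result(perms.size());
  for (std::size_t i = 0; i < perms.size(); ++i)
    result[i] = shape[perms[i]];
  return result;
}

// Single permutation equivalent to transposing by `first`, then by `second`:
// out[i] = mid[second[i]] = in[first[second[i]]].
Perms composePerms(const Perms &first, const Perms &second) {
  assert(first.size() == second.size() && "perms must have the same rank");
  Perms result(second.size());
  for (std::size_t i = 0; i < second.size(); ++i)
    result[i] = first[second[i]];
  return result;
}

// Inverse permutation: inverse[perms[i]] = i.
Perms invertPerms(const Perms &perms) {
  Perms inverse(perms.size());
  for (std::size_t i = 0; i < perms.size(); ++i)
    inverse[perms[i]] = static_cast<int32_t>(i);
  return inverse;
}

// Two transposes cancel exactly when their composition is the identity.
bool nullifies(const Perms &first, const Perms &second) {
  return first.size() == second.size() &&
         isIdentity(composePerms(first, second));
}
```

For instance, `nullifies({0, 2, 3, 1}, {0, 3, 1, 2})` holds, which is why the pair of transposes around `tosa.ceil` in `@test_transpose_tracks_to_nullifying_single_step` cancels. In the pass, this condition is evaluated on constant perms tensors; the sketch only models the index arithmetic.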
diff --git a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
new file mode 100644
index 00000000000000..3f0d7544083a42
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir
@@ -0,0 +1,649 @@
+// RUN: mlir-opt --verify-diagnostics --split-input-file --verify-each --tosa-reduce-transposes %s | FileCheck %s
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_single_step
+// CHECK-NEXT: %[[RESULT:.*]] = tosa.ceil %arg0
+// CHECK-NEXT: return %[[RESULT]]
+func.func @test_transpose_tracks_to_nullifying_single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %ceil = tosa.ceil %0 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %ceil, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_multi_unary_step
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %[[CLAMP]]
+// CHECK-NEXT: %[[NOT:.*]] = tosa.bitwise_not %[[ABS]]
+// CHECK-NEXT: return %[[NOT]]
+func.func @test_transpose_tracks_to_nullifying_multi_unary_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %bitwise_not, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_diverging_binary
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %arg1
+// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
+// CHECK-NEXT: return %[[ADD]]
+func.func @test_transpose_tracks_to_nullifying_diverging_binary(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_diverging_binary_with_broadcasting
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %arg1
+// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
+// CHECK-NEXT: return %[[ADD]]
+func.func @test_transpose_tracks_to_nullifying_diverging_binary_with_broadcasting(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x1x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x1x4xi32>, tensor<4xi32>) -> tensor<1x1x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x1x4x2xi32>) -> tensor<1x1x4x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x1x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_converging_binary
+// CHECK-NEXT: %[[RESULT:.*]] = tosa.add %arg0, %arg0
+// CHECK-NEXT: return %[[RESULT]]
+func.func @test_transpose_tracks_to_nullifying_converging_binary(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %0, %0 : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_torch_conv2d_with_elementwise_in_between
+// CHECK: %[[CONV1:.*]] = tosa.conv2d
+// CHECK: %[[CEIL:.*]] = tosa.ceil %[[CONV1]]
+// CHECK: %[[CONV2:.*]] = tosa.conv2d %[[CEIL]]
+// CHECK: %[[FLOOR:.*]] = tosa.floor %[[CONV2]]
+// CHECK: %[[CONV3:.*]] = tosa.conv2d %[[FLOOR]]
+// CHECK: %[[RES:.*]] = tosa.transpose %[[CONV3]]
+// CHECK: return %[[RES]]
+func.func @test_torch_conv2d_with_elementwise_in_between(%arg0: tensor<3x3x10x10xf32>) -> tensor<3x3x7x7xf32> {
+ %0 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32_2> : tensor<3xf32>}> : () -> tensor<3xf32>
+ %1 = "tosa.const"() <{value = dense_resource<torch_tensor_3_3_2_2_torch.float32_2> : tensor<3x3x2x2xf32>}> : () -> tensor<3x3x2x2xf32>
+ %2 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32_1> : tensor<3xf32>}> : () -> tensor<3xf32>
+ %3 = "tosa.const"() <{value = dense_resource<torch_tensor_3_3_2_2_torch.float32_1> : tensor<3x3x2x2xf32>}> : () -> tensor<3x3x2x2xf32>
+ %4 = "tosa.const"() <{value = dense_resource<torch_tensor_3_3_2_2_torch.float32> : tensor<3x3x2x2xf32>}> : () -> tensor<3x3x2x2xf32>
+ %5 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>}> : () -> tensor<3xf32>
+ %6 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %7 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %8 = tosa.transpose %arg0, %6 : (tensor<3x3x10x10xf32>, tensor<4xi32>) -> tensor<3x10x10x3xf32>
+ %9 = tosa.transpose %4, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %10 = tosa.conv2d %8, %9, %5 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x10x10x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x9x9x3xf32>
+ %11 = tosa.transpose %10, %7 : (tensor<3x9x9x3xf32>, tensor<4xi32>) -> tensor<3x3x9x9xf32>
+ %12 = tosa.ceil %11 : (tensor<3x3x9x9xf32>) -> tensor<3x3x9x9xf32>
+ %13 = tosa.transpose %12, %6 : (tensor<3x3x9x9xf32>, tensor<4xi32>) -> tensor<3x9x9x3xf32>
+ %14 = tosa.transpose %3, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %15 = tosa.conv2d %13, %14, %2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x9x9x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x8x8x3xf32>
+ %16 = tosa.transpose %15, %7 : (tensor<3x8x8x3xf32>, tensor<4xi32>) -> tensor<3x3x8x8xf32>
+ %17 = tosa.floor %16 : (tensor<3x3x8x8xf32>) -> tensor<3x3x8x8xf32>
+ %18 = tosa.transpose %17, %6 : (tensor<3x3x8x8xf32>, tensor<4xi32>) -> tensor<3x8x8x3xf32>
+ %19 = tosa.transpose %1, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %20 = tosa.conv2d %18, %19, %0 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x8x8x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x7x7x3xf32>
+ %21 = tosa.transpose %20, %7 : (tensor<3x7x7x3xf32>, tensor<4xi32>) -> tensor<3x3x7x7xf32>
+ return %21 : tensor<3x3x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_mulop_conversion
+// CHECK-NEXT: %[[RES:.*]] = tosa.mul %arg0, %arg1
+// CHECK-NEXT: return %[[RES]]
+func.func @test_mulop_conversion(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %mul = tosa.mul %transpose0, %transpose1 {shift = 0 : i8} : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %mul, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// COM: this case is a reshape we don't convert, since we can't fold the transpose into it.
+// COM: a transform actually occurs under the hood, but it results in identical IR.
+// CHECK-LABEL: @test_basic_non_broadcasting_reshape
+// CHECK: "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: tosa.reshape %arg0 {new_shape = array<i64: 1, 3, 2>} : (tensor<2x3xi32>) -> tensor<1x3x2xi32>
+// CHECK: tosa.transpose %1, %0 : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+func.func @test_basic_non_broadcasting_reshape(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> {
+ %perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, 3, 2>} : (tensor<2x3xi32>) -> tensor<1x3x2xi32>
+ %2 = tosa.transpose %1, %perms : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+ return %2 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_broadcasting_reshape
+// CHECK: %[[RES:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, -1>} : (tensor<?xi32>) -> tensor<1x1x?xi32>
+// CHECK: return %[[RES]]
+func.func @test_dynamic_broadcasting_reshape(%arg0: tensor<?xi32>) -> tensor<1x1x?xi32> {
+ %perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, -1, 1>} : (tensor<?xi32>) -> tensor<1x?x1xi32>
+ %2 = tosa.transpose %1, %perms : (tensor<1x?x1xi32>, tensor<3xi32>) -> tensor<1x1x?xi32>
+ return %2 : tensor<1x1x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_for_broadcast
+// CHECK-DAG: %[[RESHAPE_INPUT:.*]] = "tosa.const"() <{value = dense<[1, 2, 3, 4]>
+// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %[[RESHAPE_INPUT]] {new_shape = array<i64: 4, 1, 1>}
+// CHECK-DAG: %[[ADD:.*]] = tosa.add %arg0, %[[RESHAPE]]
+// CHECK: return %[[ADD]]
+func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2xi32> {
+ %0 = "tosa.const"() {value = dense<[1,2,3,4]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %reshape = tosa.reshape %0 {new_shape = array<i64: 1, 1, 4>} : (tensor<4xi32>) -> tensor<1x1x4xi32>
+ %perms0 = "tosa.const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<4x3x2xi32>, tensor<3xi32>) -> tensor<2x3x4xi32>
+ %add = tosa.add %transpose0, %reshape : (tensor<2x3x4xi32>, tensor<1x1x4xi32>) -> tensor<2x3x4xi32>
+ %transpose1 = tosa.transpose %add, %perms0 : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x3x2xi32>
+ return %transpose1 : tensor<4x3x2xi32>
+}
+
+// -----
+
+// COM: taken directly from a ResNet18 translation.
+// COM: change: %74 is a function argument instead of the result of a conv2d.
+
+// CHECK-LABEL: @test_resnet18_common_case
+// COM: note that %74 is now represented by %arg2
+// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[VAL_6:.*]] = tosa.add %arg1, %[[VAL_4]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = tosa.pow %[[VAL_6]], %[[VAL_5]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor<64xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_9:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_10:.*]] = tosa.sub %arg2, %[[VAL_9]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_12:.*]] = tosa.mul %[[VAL_10]], %[[VAL_11]] {shift = 0 : i8} : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_15:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_16:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_17:.*]] = tosa.clamp %[[VAL_16]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK: return %[[VAL_17]] : tensor<1x112x112x64xf32>
+
+func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> {
+ %59 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
+ %60 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
+ %63 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %64 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %69 = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %70 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %75 = tosa.transpose %74, %64 : (tensor<1x112x112x64xf32>, tensor<4xi32>) -> tensor<1x64x112x112xf32>
+ %76 = tosa.add %arg1, %69 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+ %77 = tosa.pow %76, %70 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+ %78 = tosa.reciprocal %77 : (tensor<64xf32>) -> tensor<64xf32>
+ %79 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %81 = tosa.reshape %78 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %82 = tosa.mul %80, %81 {shift = 0 : i8} : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %83 = tosa.reshape %60 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %84 = tosa.mul %82, %83 {shift = 0 : i8} : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %85 = tosa.reshape %59 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %87 = tosa.clamp %86 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>
+ %88 = tosa.transpose %87, %63 : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
+ return %88 : tensor<1x112x112x64xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @test_back_to_back_nullifiers
+// CHECK: %[[PERMS:.*]] = "tosa.const"
+// CHECK: %[[RES:.*]] = tosa.transpose %arg0, %[[PERMS]]
+// CHECK: return %[[RES]]
+func.func @test_back_to_back_nullifiers(%arg0: tensor<2x3xi32>) -> tensor<3x2xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %2 = tosa.transpose %1, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ return %2 : tensor<3x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_back_to_back_nullifiers_different_transposes
+// CHECK: %[[PERMS:.*]] = "tosa.const"
+// CHECK: %[[RES:.*]] = tosa.transpose %arg0, %[[PERMS]]
+// CHECK: return %[[RES]]
+func.func @test_back_to_back_nullifiers_different_transposes(%arg0: tensor<2x3x4x5xi32>) -> tensor<2x4x5x3xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<2x4x5x3xi32>
+ %1 = tosa.transpose %0, %perms1 : (tensor<2x4x5x3xi32>, tensor<4xi32>) -> tensor<2x3x4x5xi32>
+ %2 = tosa.transpose %1, %perms0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<2x4x5x3xi32>
+ return %2 : tensor<2x4x5x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_transform_if_outside_fan_in_cone
+// CHECK: tosa.const
+// CHECK: %[[CLAMP_IN:.*]] = tosa.transpose
+// CHECK: %[[RES2:.*]] = tosa.clamp %[[CLAMP_IN]]
+// CHECK: tosa.const
+// CHECK: %[[RES1:.*]] = tosa.transpose
+// CHECK: return %[[RES1]], %[[RES2]]
+func.func @test_no_transform_if_outside_fan_in_cone(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %clamp : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_two_different_downstream_converge_to_reshape_same_perms
+// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %arg0
+// CHECK-DAG: %[[CLAMP:.*]] = tosa.clamp %[[RESHAPE]]
+// CHECK: return %[[RESHAPE]], %[[CLAMP]]
+func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: tensor<64xf32>) -> (tensor<1x1x64xf32>, tensor<1x1x64xf32>) {
+ %0 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+ %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1>} : (tensor<64xf32>) -> tensor<1x64x1xf32>
+ %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32>
+ %3 = tosa.transpose %1, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
+ %4 = tosa.transpose %2, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
+ return %3, %4 : tensor<1x1x64xf32>, tensor<1x1x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_two_different_downstream_converge_to_reshape_different_perms
+// CHECK-DAG: tosa.const
+// CHECK-DAG: tosa.const
+// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape
+// CHECK-DAG: %[[CLAMP:.*]] = tosa.clamp %[[RESHAPE]]
+// CHECK-DAG: %[[RET1:.*]] = tosa.transpose
+// CHECK-DAG: %[[RET2:.*]] = tosa.transpose
+// CHECK-DAG: return %[[RET1]], %[[RET2]]
+func.func @test_two_different_downstream_converge_to_reshape_different_perms(%arg0: tensor<64xf32>) -> (tensor<1x1x64xf32>, tensor<64x1x1xf32>) {
+ %0 = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+ %1 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+ %2 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1>} : (tensor<64xf32>) -> tensor<1x64x1xf32>
+ %3 = tosa.clamp %2 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32>
+ %4 = tosa.transpose %2, %1 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
+ %5 = tosa.transpose %3, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<64x1x1xf32>
+ return %4, %5 : tensor<1x1x64xf32>, tensor<64x1x1xf32>
+}
+
+// -----
+
+// COM: no transform
+// CHECK-LABEL: @test_outside_perms_usage_of_fan_in
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.clamp
+// CHECK: %[[RES1:.*]] = tosa.transpose
+// CHECK: %[[RES2:.*]] = tosa.add
+// CHECK: return %[[RES1]], %[[RES2]]
+func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> (tensor<2x3xf32>, tensor<3x2xf32>) { %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x2xf32>) -> tensor<3x2xf32>
+ %3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ %4 = tosa.add %arg1, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
+ return %3, %4 : tensor<2x3xf32>, tensor<3x2xf32>
+}
+
+// -----
+
+// COM: this use case is important for ResNet. We want to allow these, but disallow them if a use falls into an illegal area (outside the fan-ins that get converted),
+// COM: since we would then get duplicate ops.
+// CHECK-LABEL: @test_use_present_in_another_valid_perms_fan_in
+// CHECK-DAG: %[[NEW_CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-DAG: %[[NEW_ADD:.*]] = tosa.add %arg1, %[[NEW_CLAMP]]
+// CHECK: return %[[NEW_CLAMP]], %[[NEW_ADD]]
+func.func @test_use_present_in_another_valid_perms_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
+ %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x2xf32>) -> tensor<3x2xf32>
+ %3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ %4 = tosa.transpose %arg1, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %5 = tosa.add %4, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
+ %6 = tosa.transpose %5, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ return %3, %6 : tensor<2x3xf32>, tensor<2x3xf32>
+}
+
+// -----
+
+// COM: no transform, since we would get duplicates
+// CHECK-LABEL: @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: %[[CEIL:.*]] = tosa.ceil
+// CHECK: %[[ADD:.*]] = tosa.add %[[CEIL]]
+// CHECK: %[[RES1:.*]] = tosa.transpose %[[CEIL]]
+// CHECK: %[[RES2:.*]] = tosa.transpose %[[ADD]]
+// CHECK: return %[[RES1]], %[[RES2]]
+func.func @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents(%arg0: tensor<2x3xi32>, %arg1: tensor<3x2xi32>) -> (tensor<2x3xi32>, tensor<2x3xi32>) {
+ %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ %2 = tosa.ceil %1 : (tensor<3x2xi32>) -> tensor<3x2xi32>
+ %3 = tosa.add %2, %arg1 : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
+ %4 = tosa.transpose %2, %0 : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %5 = tosa.transpose %3, %0 : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ return %4, %5 : tensor<2x3xi32>, tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_direct_use_in_other_transpose_with_same_perms
+// CHECK-NEXT: %[[RES:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: return %[[RES]], %[[RES]]
+func.func @test_direct_use_in_other_transpose_with_same_perms(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_const_transpose
+// CHECK: %[[NEW:.*]] = "tosa.const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+// CHECK-NOT: tosa.transpose
+// CHECK: return %[[NEW]]
+func.func @test_const_transpose() -> tensor<2x3xi32> {
+ %0 = "tosa.const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ return %1 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_const_single_step
+// CHECK: %[[NEW_CONST:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %[[NEW_CONST]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK-NOT: tosa.transpose
+// CHECK: return %[[NEW_CLAMP]]
+func.func @test_transpose_tracks_to_const_single_step() -> tensor<1x2x3x4xi32> {
+ %0 = "tosa.const"() {value = dense<0> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_unary_path_to_const
+// CHECK: %[[NEW_CONST:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %[[NEW_CONST]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_ABS:.*]] = tosa.abs %[[NEW_CLAMP]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_NOT:.*]] = tosa.bitwise_not %[[NEW_ABS]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: return %[[NEW_NOT]]
+func.func @test_static_unary_path_to_const() -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = "tosa.const"() {value = dense<1> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %bitwise_not, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_diverges_to_non_splat_const_and_nullifying
+// CHECK: %[[NEW_CONST:.*]] = "tosa.const"()
+// CHECK-SAME{LITERAL}: dense<[[[[1, 3, 5, 7], [9, 11, 13, 15], [17, 19, 21, 23]], [[2, 4, 6, 8], [10, 12, 14, 16], [18, 20, 22, 24]]]]>
+// CHECK-SAME: tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %arg0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_ABS:.*]] = tosa.abs %[[NEW_CONST]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_ADD:.*]] = tosa.add %[[NEW_ABS]], %[[NEW_CLAMP]] : (tensor<1x2x3x4xi32>, tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: return %[[NEW_ADD]]
+func.func @test_static_diverges_to_non_splat_const_and_nullifying(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %const = "tosa.const"() {value = dense<[[[[1, 2], [3, 4], [5, 6], [7, 8]],
+ [[9, 10], [11, 12], [13, 14], [15, 16]],
+ [[17, 18], [19, 20], [21, 22], [23, 24]]]]> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %const : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %abs, %clamp : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms2 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_multi_downstream_both_nullify
+// CHECK-NEXT: %[[RES:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: return %[[RES]], %[[RES]]
+func.func @test_multi_downstream_both_nullify(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// COM: we intentionally do not perform this transformation, since %clamp would then be needed in both its transposed and untransposed forms, duplicating it.
+// CHECK-LABEL: @test_multi_downstream_one_nullifies_upstream_other_does_not
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.clamp
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.transpose
+func.func @test_multi_downstream_one_nullifies_upstream_other_does_not(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unknown_dim_inner_replacement_matches
+// CHECK-NEXT: return %arg0
+func.func @test_unknown_dim_inner_replacement_matches(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<?x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<?x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ return %1 : tensor<3x2xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @test_unknown_dim_outer_replacement_matches
+// CHECK-NEXT: return %arg0
+func.func @test_unknown_dim_outer_replacement_matches(%arg0: tensor<3x?xi32>) -> tensor<3x?xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x?xi32>
+ return %1 : tensor<3x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_diverging_binary_unknown_dim_replacements_match
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %arg1
+// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
+// CHECK-NEXT: return %[[ADD]]
+func.func @test_transpose_tracks_to_nullifying_diverging_binary_unknown_dim_replacements_match(%arg0: tensor<1x?x3x4xi32>, %arg1: tensor<1x2x?x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x3x4x?xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x?x4xi32>, tensor<4xi32>) -> tensor<1x?x?x2xi32>
+ %clamp = tosa.clamp %transpose0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<?x3x4x?xi32>) -> tensor<?x3x4x?xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x?x?x2xi32>) -> tensor<1x?x?x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<?x3x4x?xi32>, tensor<1x?x?x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// COM: the permutations are not a constant, so we cannot do anything to the transpose in this case.
+// CHECK-LABEL: @test_unimplemented_non_const_perms
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_non_const_perms(%perms: tensor<2xi32>) -> tensor<?x?xi32> {
+ %0 = "tosa.const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<?x?xi32>
+ return %1 : tensor<?x?xi32>
+}
+
+// -----
+
+// COM: because the transpose tracks back to a transpose that it does not invert, we cannot remove the transposes entirely.
+// COM: later versions of the pass may fold these into a single transpose.
+// CHECK-LABEL: @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_single_step
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x4x3xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x4x3x2xi32>
+ %clamp = tosa.clamp %0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<1x4x3x2xi32>) -> tensor<1x4x3x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<1x4x3x2xi32>, tensor<4xi32>) -> tensor<1x2x4x3xi32>
+ return %1 : tensor<1x2x4x3xi32>
+}
+
+// -----
+
+// COM: we do not handle this case; it requires resolving the unknown dimension of %arg0 against the static result type.
+// CHECK-LABEL: @test_unimplemented_unknown_dim_input_nullifying_pair
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_unknown_dim_input_nullifying_pair(%arg0: tensor<3x?xi32>) -> tensor<3x2xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ return %1 : tensor<3x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unimplemented_unknown_dim_replacement_does_not_match
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_unknown_dim_replacement_does_not_match(%arg0: tensor<3x?xi32>) -> tensor<?x?xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<?x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<?x3xi32>, tensor<2xi32>) -> tensor<?x?xi32>
+ return %1 : tensor<?x?xi32>
+}
+
+// -----
+
+// COM: this could be converted if --tosa-infer-shapes were run beforehand.
+// CHECK-LABEL: @test_unimplemented_unranked_tensors_present
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_unranked_tensors_present(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+ %perms = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unimplemented_unranked_everything
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_unranked_everything(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unimplemented_static_diverges_to_one_nullifying_one_non_nullifying
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: tosa.abs
+// CHECK-NEXT: tosa.add
+// CHECK-NEXT: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_static_diverges_to_one_nullifying_one_non_nullifying(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x4x3xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms1 : (tensor<1x2x4x3xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms2 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
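For checking the expected outputs of the tests above by hand, here is a minimal Python model of two facts they rely on:

- A pair of transposes cancels exactly when their composed permutation is the identity.
- A transpose of a constant can be folded by permuting the constant's elements.

This is a sketch that assumes the tosa.transpose convention out.shape[i] == in.shape[perms[i]]. It is not the pass's C++ implementation. The helpers `compose`, `is_identity`, and `transpose_flat` are hypothetical.

```python
from itertools import product


def compose(first, second):
    """Single permutation equivalent to transposing by `first`, then by `second`.

    Assumes the tosa.transpose convention out.shape[i] == in.shape[perms[i]].
    """
    return [first[p] for p in second]


def is_identity(perms):
    """True when a permutation leaves every dimension in place (the pair nullifies)."""
    return list(perms) == list(range(len(perms)))


def transpose_flat(values, shape, perms):
    """Transpose a row-major flattened tensor, as constant folding would.

    Returns (new_values, new_shape).
    """
    rank = len(shape)
    new_shape = [shape[p] for p in perms]
    # Row-major strides of the input tensor.
    strides = [1] * rank
    for d in range(rank - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    out = []
    # Walk the output in row-major order; output dim i reads input dim perms[i].
    for out_idx in product(*(range(n) for n in new_shape)):
        offset = sum(out_idx[i] * strides[perms[i]] for i in range(rank))
        out.append(values[offset])
    return out, new_shape


if __name__ == "__main__":
    # [0, 2, 3, 1] followed by [0, 3, 1, 2] cancels out.
    assert is_identity(compose([0, 2, 3, 1], [0, 3, 1, 2]))
    # [0, 3, 2, 1] followed by [0, 3, 1, 2] leaves a residual permutation.
    assert compose([0, 3, 2, 1], [0, 3, 1, 2]) == [0, 1, 3, 2]

    # Folding the 1x3x4x2 constant holding 1..24 by [0, 3, 1, 2].
    folded, folded_shape = transpose_flat(list(range(1, 25)), [1, 3, 4, 2], [0, 3, 1, 2])
    print(folded_shape)  # [1, 2, 3, 4]
    print(folded[:12])   # [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23]

    # Hoisting through an elementwise op: transpose(abs(x)) == abs(transpose(x)).
    x = [-1, 2, -3, 4, -5, 6]
    after, _ = transpose_flat([abs(v) for v in x], [3, 2], [1, 0])
    before = [abs(v) for v in transpose_flat(x, [3, 2], [1, 0])[0]]
    assert after == before
```

The folded values line up with the `CHECK-SAME{LITERAL}` dense literal in @test_static_diverges_to_non_splat_const_and_nullifying. The residual [0, 1, 3, 2] is why the transposes in @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_single_step are kept.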