[Mlir-commits] [mlir] [MLIR][TOSA] Add --tosa-remove-redundant-transposes pass (PR #108260)
Arteen Abrishami
llvmlistbot at llvm.org
Thu Sep 12 13:24:05 PDT 2024
https://github.com/arteen1000 updated https://github.com/llvm/llvm-project/pull/108260
>From ee9b2b0ac32f322f0182feb935ae249a55076875 Mon Sep 17 00:00:00 2001
From: Arteen Abrishami <arteen.abrishami at arm.com>
Date: Thu, 12 Sep 2024 20:18:29 +0000
Subject: [PATCH] [MLIR][TOSA] Add --tosa-remove-redundant-transposes pass
----------
Motivation:
----------
Some legalization pathways introduce redundant tosa.TRANSPOSE
operations that result in avoidable data movement. For example,
PyTorch -> TOSA contains a lot of unnecessary transposes due
to conversions between NCHW and NHWC.
We wish to remove all the ones that we can, since in general
it is possible to remove the overwhelming majority.
------------
Changes Made:
------------
- Add the --tosa-remove-redundant-transposes pass.
- Add TosaElementwiseOperator trait.
-------------------
High-Level Overview:
-------------------
The pass begins at a downstream transpose with some perms tensor.
It traverses the dependencies upward, accepting only operators with the TosaElementwiseOperator trait.
Dependencies must terminate in nullifying transposes (when composed, they form the identity),
reshapes we can fold the transpose into, or consts.
Conceptually, we then "bubble up" the downstream transpose until
we hit the sources. For constants, we generate a new constant, composed
with the downstream transpose. For nullifying transposes, we "cancel"
them. For reshapes, we fold the transpose into them.
We then ensure that we do not cause any duplication by "converting"
this chain we bubbled-up into its transposed form. We do this by analyzing
the dependency fan-ins across all transposes with the same perms tensor
in order to ensure that they do not have uses outside this group, which
would keep the old code section "live" so that DCE could not remove it.
We then perform a simple one-pass DCE, so no canonicalization is necessary.
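The two shape-level facts the overview relies on can be sketched in standalone C++. This is an illustration, not code from the patch: `applyPerms` and `composesToIdentity` are hypothetical names, and the sketch assumes the TOSA convention that output dimension `i` of a transpose takes input dimension `perms[i]`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Shape of a transpose result: output dimension i takes input dimension
// perms[i]. Assumes perms is a valid permutation of 0 .. rank - 1.
std::vector<int64_t> applyPerms(const std::vector<int64_t> &shape,
                                const std::vector<int32_t> &perms) {
  assert(shape.size() == perms.size());
  std::vector<int64_t> result;
  result.reserve(perms.size());
  for (int32_t p : perms)
    result.push_back(shape[p]);
  return result;
}

// Applying `upstream` and then `downstream` places input dimension
// upstream[downstream[i]] at position i; the pair is "nullifying" when
// that is i for every position.
bool composesToIdentity(const std::vector<int32_t> &upstream,
                        const std::vector<int32_t> &downstream) {
  if (upstream.size() != downstream.size())
    return false;
  for (std::size_t i = 0; i < downstream.size(); ++i)
    if (upstream[downstream[i]] != static_cast<int32_t>(i))
      return false;
  return true;
}
```

For example, the NCHW -> NHWC perms [0, 2, 3, 1] followed by the NHWC -> NCHW perms [0, 3, 1, 2] compose to the identity, which is the case the pass cancels. Bubbling the downstream transpose above an elementwise operator amounts to applying the same perms to each operand's shape.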
--------------
Impact of Pass:
--------------
Patching the dense_resource artifacts (from PyTorch) with dense attributes to
permit constant folding, we obtain the following results.
Note that data movement represents total transpose data movement, calculated
by noting which dimensions moved during the transpose.
///////////
MobilenetV3:
///////////
BEFORE total data movement: 11798776 B (11.25 MiB)
AFTER total data movement: 2998016 B (2.86 MiB)
74.6% of data movement removed.
BEFORE transposes: 82
AFTER transposes: 20
75.6% of transposes removed.
////////
ResNet18:
////////
BEFORE total data movement: 20596556 B (19.64 MiB)
AFTER total data movement: 1003520 B (0.96 MiB)
95.1% of data movement removed.
BEFORE transposes: 56
AFTER transposes: 5
91.1% of transposes removed.
////////
ResNet50:
////////
BEFORE total data movement: 83236172 B (79.38 MiB)
AFTER total data movement: 3010560 B (2.87 MiB)
96.4% of data movement removed.
BEFORE transposes: 120
AFTER transposes: 7
94.2% of transposes removed.
/////////
ResNet101:
/////////
BEFORE total data movement: 124336460 B (118.58 MiB)
AFTER total data movement: 3010560 B (2.87 MiB)
97.6% of data movement removed.
BEFORE transposes: 239
AFTER transposes: 7
97.1% of transposes removed.
/////////
ResNet152:
/////////
BEFORE total data movement: 175052108 B (166.94 MiB)
AFTER total data movement: 3010560 B (2.87 MiB)
98.3% of data movement removed.
BEFORE transposes: 358
AFTER transposes: 7
98.0% of transposes removed.
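Each reduction figure above is (before - after) / before, and the MiB values use 1 MiB = 2^20 B. A small helper for re-deriving them (hypothetical names, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

// Share of a quantity (bytes or transpose count) removed by the pass, in %.
double percentRemoved(int64_t before, int64_t after) {
  assert(before > 0 && "need a non-zero baseline");
  return 100.0 * static_cast<double>(before - after) /
         static_cast<double>(before);
}

// Converts bytes to mebibytes (1 MiB = 2^20 B).
double toMiB(int64_t bytes) {
  return static_cast<double>(bytes) / (1024.0 * 1024.0);
}
```

For MobilenetV3, `percentRemoved(11798776, 2998016)` is about 74.59 and `percentRemoved(82, 20)` is about 75.61, which match the rounded figures reported above.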
////////
Overview:
////////
We see that we remove up to 98% of transposes and eliminate
up to 98.3% of redundant transpose data movement.
In the context of ResNet50, with 120 inferences per second,
we reduce dynamic transpose data bandwidth from about 9.30 GiB/s
to about 344.5 MiB/s.
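The bandwidth figure is the per-inference transpose data movement multiplied by the inference rate. A sketch of that arithmetic (hypothetical names, binary units assumed):

```cpp
#include <cstdint>

constexpr double kMiB = 1024.0 * 1024.0; // 2^20 bytes
constexpr double kGiB = 1024.0 * kMiB;   // 2^30 bytes

// Sustained transpose traffic = bytes moved per inference * inferences/s.
double transposeBytesPerSec(int64_t bytesPerInference,
                            double inferencesPerSec) {
  return static_cast<double>(bytesPerInference) * inferencesPerSec;
}
```

With the ResNet50 numbers, `transposeBytesPerSec(83236172, 120) / kGiB` is about 9.30 before the pass, and `transposeBytesPerSec(3010560, 120) / kMiB` is about 344.53 after it.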
-----------
Future Work:
-----------
(1)
Evaluate tradeoffs with the duplication of ConstOp, especially
across many downstream transposes with different perms, which can result
in the same ConstOp being duplicated (but transposed) multiple times.
Observe tradeoffs between a lower memory footprint and potentially
converting many fan-ins of downstream transposes with the same perms,
which, if not converted, may affect the ability of other interdependent
fan-ins to convert.
(2)
Expand the class of foldable upstream ReshapeOp we permit beyond
N -> 1x1x...x1xNx1x...x1x1.
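The currently accepted reshape class, and the way the downstream perms fold into it, can be sketched as follows. This is a standalone illustration with hypothetical helper names; the pass itself operates on `tosa::ReshapeOp` via `tosa::applyTOSAPermutation`. Because the output has at most one non-unit dimension, permuting the reshape's target shape yields the same elements in transposed order, so no separate transpose is needed.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// True for reshapes of the form N -> 1x...x1xNx1x...x1.
bool isFoldableReshape(const std::vector<int64_t> &inShape,
                       const std::vector<int64_t> &outShape) {
  if (inShape.size() != 1 || outShape.empty())
    return false;
  std::ptrdiff_t ones = std::count(outShape.begin(), outShape.end(), 1);
  auto rank = static_cast<std::ptrdiff_t>(outShape.size());
  // Exactly one non-unit dimension, or N == 1 and every dimension is 1.
  return ones == rank - 1 || (ones == rank && inShape[0] == 1);
}

// Folds the downstream perms into the reshape's target shape, if allowed.
std::optional<std::vector<int64_t>>
foldPermsIntoReshape(const std::vector<int64_t> &inShape,
                     const std::vector<int64_t> &outShape,
                     const std::vector<int32_t> &perms) {
  if (!isFoldableReshape(inShape, outShape) || perms.size() != outShape.size())
    return std::nullopt;
  std::vector<int64_t> folded;
  folded.reserve(perms.size());
  for (int32_t p : perms)
    folded.push_back(outShape[p]);
  return folded;
}
```

For instance, a bias reshaped 64 -> 1x1x1x64 (NHWC) followed by perms [0, 3, 1, 2] folds into a single reshape 64 -> 1x64x1x1. A reshape such as 64 -> 8x8 is rejected, which is the restriction item (2) proposes to relax.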
(3)
Make the pass more general, beyond just allowing upstream transposes
to be nullifying. For example,
transpose1 -> ... -> transpose2
where transpose2(transpose1) does not cancel to the identity.
This can be done by propagating the downstream transpose up
and inserting after transpose1, just like how it is done for
reshape. However, in the case of chains like
transpose1 -> ... -> transpose2 -> ... -> transpose3
this could require running the current runOnOperation() function
until we converge. We can detect convergence by stopping once all transposes
whose fan-ins we can successfully collect have the owner
of their first operand being either another TransposeOp or a
ReshapeOp, since those are what we propagate to and where we leave
behind / insert another TransposeOp. Otherwise, we could potentially
have infinite looping.
Folding of the transposes is then necessary.
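The final folding step would collapse any remaining chain of transposes into one. Under the same convention as the pass (output dimension `i` takes input dimension `perms[i]`), two consecutive transposes compose as sketched below. The helper names are hypothetical and nothing here is in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// transpose(transpose(x, first), second) == transpose(x, composed), where
// composed[i] = first[second[i]].
std::vector<int32_t> composePerms(const std::vector<int32_t> &first,
                                  const std::vector<int32_t> &second) {
  assert(first.size() == second.size());
  std::vector<int32_t> composed;
  composed.reserve(second.size());
  for (int32_t s : second)
    composed.push_back(first[s]);
  return composed;
}

// A composed permutation that is the identity can be dropped entirely.
bool isIdentity(const std::vector<int32_t> &perms) {
  for (std::size_t i = 0; i < perms.size(); ++i)
    if (perms[i] != static_cast<int32_t>(i))
      return false;
  return true;
}
```

A nullifying pair such as [0, 2, 3, 1] then [0, 3, 1, 2] composes to the identity. A non-nullifying pair such as [1, 0, 2] then [0, 2, 1] composes to the single transpose [1, 2, 0].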
(4)
Add support for more instructions (for example, those that reduce
along an axis) to be one of the intervening operations in the
fan-in cones (other than those with TosaElementwiseOperator trait).
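One possible approach for rank-preserving reductions (TOSA reductions such as `tosa.reduce_sum` keep the reduced dimension with size 1) would be to remap the axis when the downstream transpose is moved above the reduction. The reduced input dimension `axis` then appears at the position `d` where `perms[d] == axis`, which is the inverse permutation applied to the axis. This is a sketch under that assumption only; `remapReductionAxis` is a hypothetical helper.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Given a rank-preserving reduction over `axis` followed by a transpose with
// `perms`, returns the axis to reduce over once the transpose is moved above
// the reduction, i.e. the position d with perms[d] == axis.
std::optional<int32_t> remapReductionAxis(int32_t axis,
                                          const std::vector<int32_t> &perms) {
  for (std::size_t d = 0; d < perms.size(); ++d)
    if (perms[d] == axis)
      return static_cast<int32_t>(d);
  return std::nullopt; // axis is out of range for these perms
}
```

For example, a reduction over C (axis 1) of an NCHW tensor followed by the NHWC perms [0, 2, 3, 1] becomes a reduction over axis 3 of the transposed input.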
(5)
Support bubbling transposes up to the input parameter. This may not
need extensive fan-in analysis, since there is no operation cost
associated with the parameter being used elsewhere.
Signed-off-by: Arteen Abrishami <arteen.abrishami at arm.com>
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 10 +
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 6 +
.../mlir/Dialect/Tosa/Transforms/Passes.td | 14 +
.../Dialect/Tosa/Transforms/CMakeLists.txt | 1 +
.../TosaRemoveRedundantTransposes.cpp | 733 ++++++++++++++++++
.../tosa-remove-redundant-transposes.mlir | 649 ++++++++++++++++
6 files changed, 1413 insertions(+)
create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp
create mode 100644 mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..64bacd0e432fe5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -206,6 +206,15 @@ def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
input, paddings, pad_value);
}]>;
+//===----------------------------------------------------------------------===//
+// TOSA Operator Trait.
+//===----------------------------------------------------------------------===//
+
+// Permits broadcasting. Elementwise trait is too strict.
+def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
+ let cppNamespace = "mlir::OpTrait::tosa";
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Class.
//===----------------------------------------------------------------------===//
@@ -219,6 +228,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape,
+ TosaElementwiseOperator,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 7ed89bff474a2e..66512cbe350ec8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -84,6 +84,12 @@ class MulOperandsAndResultElementType
}
};
+/// This class indicates that an op is tosa-elementwise (permits broadcasting,
+/// unlike Elementwise trait).
+template <typename ConcreteType>
+class TosaElementwiseOperator
+ : public TraitBase<ConcreteType, TosaElementwiseOperator> {};
+
} // namespace tosa
} // namespace OpTrait
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index a0f670de20150f..5159d258d0f26f 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -126,4 +126,18 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
];
}
+def TosaRemoveRedundantTransposes : Pass<"tosa-remove-redundant-transposes", "func::FuncOp"> {
+ let summary = "Remove redundant transposes";
+ let description = [{
+ Pass that identifies and removes redundant tosa.TRANSPOSE operations.
+ It does so by traversing dependencies of tosa.TRANSPOSE operations until they terminate in either
+ tosa.RESHAPE, a nullifying tosa.TRANSPOSE, or a tosa.CONST. It then propagates the downstream
+ transform upward through the intervening operators where possible and replaces the downstream tosa.TRANSPOSE.
+ Results are generally better when run after canonicalization and resolution of dynamic shapes.
+ This pass has an important use-case in cleaning up the results of frameworks that introduce a lot
+ of data-layout transformations when legalizing to TOSA, a common one being transformations between NHWC and NCHW
+ layouts.
+ }];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index c78a74b874aff1..624038b9b38981 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaLayerwiseConstantFoldPass.cpp
TosaMakeBroadcastable.cpp
TosaOptionalDecompositions.cpp
+ TosaRemoveRedundantTransposes.cpp
TosaTypeConverters.cpp
TosaValidation.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp
new file mode 100644
index 00000000000000..183700d117f3cc
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp
@@ -0,0 +1,733 @@
+//===- TosaRemoveRedundantTransposes.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// ----------
+// Motivation:
+// ----------
+
+// Some legalization pathways introduce redundant tosa.TRANSPOSE
+// operations that result in avoidable data movement. For example,
+// PyTorch -> TOSA contains a lot of unnecessary transposes due
+// to conversions between NCHW and NHWC.
+
+// We wish to remove all the ones that we can, since in general
+// it is possible to remove the overwhelming majority.
+
+// -------------------
+// High-Level Overview:
+// -------------------
+
+// The pass begins at a downstream transpose with some perms tensor.
+// It traverses the dependencies upward, accepting only TosaElementwise
+// operators. Dependencies must terminate in nullifying transposes (when
+// composed, they form the identity), reshapes we can fold the transpose into,
+// or consts.
+
+// Conceptually, we then "bubble up" the downstream transpose until
+// we hit the sources. For constants, we generate a new constant, composed
+// with the downstream transpose. For nullifying transposes, we "cancel"
+// them. For reshapes, we fold the transpose into them.
+
+// We then ensure that we do not cause any duplication by "converting"
+// this chain we bubbled-up into its transposed form. We do this by analyzing
+// the dependency fan-ins across all transposes with the same perms tensor
+// in order to ensure that they do not have uses outside this group, which
+// would keep the old code section "live" so that it could not be removed
+// by DCE.
+
+// We then perform a simple one-pass DCE, so no canonicalization is necessary.
+
+// --------------
+// Impact of Pass:
+// --------------
+
+// We note that up to 98.3% of transpose data movement and 98.0%
+// of transposes can be removed from MobilenetV3 and ResNet networks.
+
+// -----------
+// Future Work:
+// -----------
+
+// (1)
+
+// Evaluate tradeoffs with the duplication of ConstOp, especially
+// across many downstream transposes with different perms, which can result
+// in the same ConstOp being duplicated (but transposed) multiple times.
+
+// Observe tradeoffs between a lower memory footprint and potentially
+// converting many fan-ins of downstream transposes with the same perms,
+// which, if not converted, may affect the ability of other interdependent
+// fan-ins to convert.
+
+// (2)
+
+// Expand the class of foldable upstream ReshapeOp we permit beyond
+// N -> 1x1x...x1xNx1x...x1x1.
+
+// (3)
+
+// Make the pass more general, beyond just allowing upstream transposes
+// to be nullifying. For example,
+
+// transpose1 -> ... -> transpose2
+
+// where transpose2(transpose1) does not cancel to the identity.
+
+// This can be done by propagating the downstream transpose up
+// and inserting after transpose1, just like how it is done for
+// reshape. However, in the case of chains like
+
+// transpose1 -> ... -> transpose2 -> ... -> transpose3
+
+// this could require running the current runOnOperation() function
+// until we converge. We can detect convergence by stopping once all
+// transposes whose fan-ins we can successfully collect have the owner
+// of their first operand being either another TransposeOp or a
+// ReshapeOp, since those are what we propagate to and where we leave
+// behind / insert another TransposeOp. Otherwise, we could potentially
+// have infinite looping.
+
+// Folding of the transposes is then necessary.
+
+// (4)
+
+// Add support for more instructions (for example, those that reduce
+// along an axis) to be one of the intervening operations in the
+// fan-in cones (other than those with TosaElementwiseOperator trait).
+
+// (5)
+
+// Support bubbling transposes up to the input parameter. This may not
+// need extensive fan-in analysis, since there is no operation cost
+// associated with the parameter being used elsewhere.
+
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <memory>
+#include <set>
+#include <stack>
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAREMOVEREDUNDANTTRANSPOSES
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TOSA Remove Redundant Transposes Pass.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct TosaRemoveRedundantTransposes final
+ : public tosa::impl::TosaRemoveRedundantTransposesBase<
+ TosaRemoveRedundantTransposes> {
+ void runOnOperation() override;
+
+private:
+ // This will collect all the data dependencies for the given Operation
+ // up to and including ConstOp, ReshapeOp, and TransposeOp.
+ bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
+ bool convertDependentOps(SetVector<Operation *> &dependentOps,
+ DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter,
+ ArrayRef<int32_t> downstreamPerms);
+
+ // Checks if the two permutations, when applied consecutively, result
+ // in the identity.
+ bool areNullifyingTransposes(ArrayRef<int32_t> perms1,
+ ArrayRef<int32_t> perms2);
+
+ // This is meant to apply to operations with the TosaElementwiseOperator
+ // trait.
+ std::optional<Value>
+ buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // Handles another TransposeOp found as a dependency of the downstream one:
+ //   %0 = tosa.transpose %arg0  <- applies to this
+ //   %1 = tosa.transpose %0     <- when tracking back from this
+ std::optional<Value>
+ buildMappedToValue(tosa::TransposeOp transposeOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // Inserts the downstream TransposeOp after the ReshapeOp, since we generally
+ // cannot propagate through it.
+ std::optional<Value>
+ buildMappedToValue(tosa::ReshapeOp reshapeOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // We may have something like:
+ // %0 = tosa.const
+ // %1 = tosa.transpose
+ // %2 = tosa.add %0, %1
+ // %3 = tosa.transpose %2
+ // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
+ // in MobilenetV3.
+ std::optional<Value>
+ buildMappedToValue(tosa::ConstOp constOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // Checks which TransposeOp we should "replace", turning their converted
+ // chains of ops, through which they were propagated, "live", and the old code
+ // "dead." Attempts to avoid doing so when doing so would result in the old
+ // code staying "live," resulting in duplication.
+ std::set<tosa::TransposeOp> getGoodReplacements(
+ ArrayRef<int32_t> perms,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Helper function for getGoodReplacements to check if some TransposeOp's
+ // dependencies are OK.
+ bool dependenciesAreValid(
+ ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+ std::set<tosa::TransposeOp> &validTransposes,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Applies perms to the DenseElementsAttr.
+ // If it returns std::nullopt, it also triggers pass failure, since verifier
+ // guarantees from TOSA are not in place (and otherwise, if used elsewhere
+ // it should fail).
+ // This is a basic API and may benefit from refactor into the core MLIR APIs.
+ std::optional<DenseElementsAttr>
+ transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
+};
+
+std::optional<DenseElementsAttr>
+TosaRemoveRedundantTransposes::transposeDenseAttribute(
+ DenseElementsAttr input, ArrayRef<int32_t> perms) {
+ RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
+ RankedTensorType newType = RankedTensorType::get(
+ tosa::applyTOSAPermutation(oldType.getShape(), perms),
+ oldType.getElementType());
+ size_t rank = oldType.getRank();
+
+ if (input.isSplat())
+ return input.reshape(newType);
+ // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
+ // 0.
+ // If not in place, something is very wrong.
+ if (rank <= 0 || oldType.getNumElements() <= 0 || perms.size() != rank) {
+ signalPassFailure();
+ return std::nullopt;
+ }
+
+ // The algorithm is approximately as follows:
+ // Input: perms, input flat array, input tensor type.
+ // (1/2) Determine the strides of the input/output as if row-major.
+ // (3) Reorder the input strides to match the index order of the output.
+ // (4) Process dimension by dimension. Example: perms 2, 0, 1;
+ // input 2x3x4; output 4x2x3. For i < 4, j < 2, k < 3:
+ //   output[i][j][k] = input[j][k][i]
+ //   output[6i + 3j + k] = input[12j + 4k + i]
+ // We reorder the input strides to input[i + 12j + 4k] so we may process
+ // layer-by-layer.
+
+ // Step 1/2: Strides for input. We ignore output since row-major and can just
+ // push_back.
+
+ SmallVector<int64_t> originalInputStrides(rank);
+ originalInputStrides[rank - 1] = 1;
+ // index with int64_t to avoid overflow
+ for (int64_t i = rank - 2; i >= 0; i--)
+ originalInputStrides[i] =
+ originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
+
+ // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
+ // output which is done in row-major order.
+
+ SmallVector<int64_t> newInputStrides;
+ newInputStrides.reserve(rank);
+ for (int32_t v : perms)
+ newInputStrides.push_back(originalInputStrides[v]);
+
+ // Step 4: Write out the transposed "flat array" dimension by dimension.
+
+ auto inputArray = input.getValues<Attribute>();
+ SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
+ for (size_t i = 0; i < rank; i++)
+ boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
+
+ SmallVector<Attribute> resultArray;
+ resultArray.reserve(inputArray.size());
+
+ std::function<void(int64_t,
+ SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
+ processTransposeDim = [&](auto accumulatedIndex, auto it) {
+ if (it == boundsAndStrides.end()) {
+ resultArray.push_back(inputArray[accumulatedIndex]);
+ return;
+ }
+
+ for (int64_t i = 0; i < it->first; i++) {
+ int64_t j = accumulatedIndex + i * it->second;
+ processTransposeDim(j, it + 1);
+ }
+ };
+
+ processTransposeDim(0, boundsAndStrides.begin());
+
+ return DenseElementsAttr::get(newType, resultArray);
+}
+
+// The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
+// as the sources of the data dependencies, and TosaElementwiseOperator
+// after that, if the function returns true.
+bool TosaRemoveRedundantTransposes::collectFanIn(
+ Operation *op, SetVector<Operation *> &collected) {
+ // Can occur if defined through the parameter to a func.func.
+ if (!op)
+ return false;
+
+ if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
+ return false;
+
+ // Prevent extra work if already seen.
+ if (collected.contains(op))
+ return true;
+
+ // Throw it out so later don't have to deal with this.
+ if (op->getNumResults() != 1 ||
+ !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
+ return false;
+
+ // We don't wish to traverse up a ReshapeOp,
+ // since generally we can't propagate a TransposeOp through it.
+ // TransposeOp, ReshapeOp, ConstOp will have no in-edges in the data
+ // dependency graph we construct for the downstream TransposeOp.
+ if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
+ !llvm::isa<tosa::ConstOp>(op)) {
+
+ if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+ return false;
+
+ for (Value operand : op->getOperands()) {
+
+ if (!collectFanIn(operand.getDefiningOp(), collected))
+ return false;
+ }
+ }
+
+ // Insert in topological order.
+ collected.insert(op);
+
+ return true;
+}
+
+// Assumes, due to the verification of TransposeOp, that the perms arrays
+// are permutations of 0 .. perms.size() - 1.
+bool TosaRemoveRedundantTransposes::areNullifyingTransposes(
+ ArrayRef<int32_t> perms1, ArrayRef<int32_t> perms2) {
+ if (perms1.size() != perms2.size())
+ return false;
+ for (int32_t i = 0; i < static_cast<int32_t>(perms1.size()); i++)
+ if (perms2[perms1[i]] != i)
+ return false;
+ return true;
+}
+
+// Primary overload for those with TosaElementwiseOperator trait.
+// The other ones handle the case of the operations that occur at the
+// roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+ Operation *op, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+ if (op->getNumResults() != 1 ||
+ !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+ return std::nullopt;
+
+ auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
+ SmallVector<Value, 3> operands;
+ for (Value v : op->getOperands()) {
+ if (valuesMap.contains(v)) {
+ operands.push_back(valuesMap.at(v));
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ // Conceptually, we propagate the downstream TransposeOp through these
+ // intervening operations. For example,
+ //   %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
+ //   %1 = tosa.transpose %0 {perms = [1, 0]}
+ // becomes:
+ //   %0 = tosa.transpose %input {perms = [1, 0]}
+ //   %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>
+ // We construct this new tosa.clamp here, but it does not turn "live" until
+ // the final downstream transpose in the chain (whose dependencies we are
+ // currently traversing) is replaced with the proper value from this chain.
+ return rewriter
+ .create(op->getLoc(),
+ rewriter.getStringAttr(op->getName().getStringRef()), operands,
+ RankedTensorType::get(tosa::applyTOSAPermutation(
+ resultType.getShape(), downstreamPerms),
+ resultType.getElementType()),
+ op->getAttrs())
+ ->getResult(0);
+}
+
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+ tosa::TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+ SmallVector<int32_t> perms;
+ if (failed(transposeOp.getConstantPerms(perms)) ||
+ !areNullifyingTransposes(downstreamPerms, perms))
+ return std::nullopt;
+ return transposeOp.getInput1();
+}
+
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+ tosa::ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+ auto reshapeOutput = reshapeOp.getOutput();
+ auto reshapeInputType =
+ llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
+ // Want reshape N -> 1x1x...x1xNx1x...x1x1; check the type before using it.
+ if (!reshapeInputType || reshapeInputType.getRank() != 1)
+ return std::nullopt;
+ auto reshapeInputShape = reshapeInputType.getShape();
+ auto reshapeOutputType =
+ llvm::cast<RankedTensorType>(reshapeOutput.getType());
+
+ // Instead of inserting a TransposeOp here, we
+ // check if we can fold it into the ReshapeOp.
+ // There are more complex cases where this is possible,
+ // and this check can be extended.
+
+ // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1
+ auto shape = reshapeOutputType.getShape();
+ size_t ones = llvm::count(shape, 1);
+ // Accept N != 1 (one non-unit dim) or N == 1 (all dims are 1).
+ if (ones != shape.size() - 1 &&
+ !(ones == shape.size() && reshapeInputShape[0] == 1))
+ return std::nullopt;
+
+ // Do not insert a TransposeOp, instead we fold the reshape and its attribute.
+ auto foldedReshape = rewriter.create<tosa::ReshapeOp>(
+ reshapeOp.getLoc(),
+ RankedTensorType::get(tosa::applyTOSAPermutation(shape, downstreamPerms),
+ reshapeOutputType.getElementType()),
+ reshapeOp.getInput1(),
+ rewriter.getDenseI64ArrayAttr(tosa::applyTOSAPermutation(
+ reshapeOp.getNewShape(), downstreamPerms)));
+ return foldedReshape->getResult(0);
+}
+
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+ tosa::ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+ auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!denseAttr)
+ return std::nullopt;
+ auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, downstreamPerms);
+ if (!maybeNewDenseAttr.has_value())
+ return std::nullopt;
+ auto newDenseAttr = maybeNewDenseAttr.value();
+ auto newConstOp = rewriter.create<tosa::ConstOp>(
+ constOp.getLoc(), newDenseAttr.getType(), newDenseAttr);
+ return newConstOp->getResult(0);
+}
+
+bool TosaRemoveRedundantTransposes::convertDependentOps(
+ SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+
+ for (Operation *op : dependentOps) {
+ if (!op || op->getNumResults() != 1)
+ return false;
+
+ Value priorValue = op->getResult(0);
+
+ // It's possible on a prior transposeOp
+ // we had the same dependency and already resolved it.
+ if (valuesMap.contains(priorValue))
+ continue;
+
+ // Keep converted ops close to the original.
+ rewriter.setInsertionPointAfter(op);
+
+ std::optional<Value> maybeValue =
+ llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
+ .Case<tosa::TransposeOp>([&](tosa::TransposeOp transposeOp) {
+ return buildMappedToValue(transposeOp, valuesMap, rewriter,
+ downstreamPerms);
+ })
+ .Case<tosa::ReshapeOp>([&](tosa::ReshapeOp reshapeOp) {
+ return buildMappedToValue(reshapeOp, valuesMap, rewriter,
+ downstreamPerms);
+ })
+ .Case<tosa::ConstOp>([&](tosa::ConstOp constOp) {
+ return buildMappedToValue(constOp, valuesMap, rewriter,
+ downstreamPerms);
+ })
+ .Default([&](Operation *op) {
+ return buildMappedToValue(op, valuesMap, rewriter,
+ downstreamPerms);
+ });
+
+ if (!maybeValue.has_value())
+ return false;
+
+ valuesMap[priorValue] = maybeValue.value();
+ }
+
+ return true;
+}
+
+// Dependencies are valid for an operation if none of them occur outside
+// of the proper fan-in cones of the downstream TransposeOp with the same perms
+// that we can replace. Described in more detail within.
+bool TosaRemoveRedundantTransposes::dependenciesAreValid(
+ ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+ std::set<tosa::TransposeOp> &validTransposes,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+ &transposeInfo) {
+ for (Operation *op : dependentOps) {
+
+ // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
+ // This can be changed later if we find the memory impact is too high.
+ if (llvm::isa<tosa::ConstOp>(op))
+ continue;
+
+ for (OpOperand &use : op->getUses()) {
+ // Want the uses to be (1) contained in the dependentOps of other
+ // validTransposes, or (2) to be directly used in a TransposeOp with the
+ // same perms. For (2) it means the fan-in is a subset of our
+ // dependentOps, so it is also a validTranspose that will eventually be
+ // replaced.
+ Operation *user = use.getOwner();
+ if (auto otherTranspose = llvm::dyn_cast<tosa::TransposeOp>(user)) {
+ SmallVector<int32_t> otherPerms;
+
+ // Can later think about cases where transpose -> transpose
+ // or reshape -> transpose, where the transposes are not necessarily
+ // the same perms as the downstream, if implementing a more general
+ // transform. These could be permitted.
+ if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
+ !llvm::equal(perms, otherPerms))
+ return false;
+
+ } else if (llvm::none_of(
+ transposeInfo,
+ [&validTransposes,
+ user](const std::pair<tosa::TransposeOp,
+ SetVector<Operation *>> &info) {
+ const auto &[transposeOp, dependentOps] = info;
+ return validTransposes.count(transposeOp) &&
+ dependentOps.contains(user);
+ })) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+// Getting the set of TransposeOp that we can replace without causing
+// the old fan-in cones of any TransposeOp to remain "live", i.e., not being
+// dead code. This is done by iterating the set until convergence, since
+// if you are used outside your own fan-in cone, it's possible to be used
+// in another fan-in cone of a TransposeOp that is being replaced -- unless
+// we find that that one has a usage outside of it too.
+std::set<tosa::TransposeOp> TosaRemoveRedundantTransposes::getGoodReplacements(
+ ArrayRef<int32_t> perms,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+ &transposeInfo) {
+ // Initially, we assume they are all good to replace,
+ // and we whittle them down based on our criteria.
+ std::set<tosa::TransposeOp> ableToReplace;
+ for (const auto &[transposeOp, _] : transposeInfo)
+ ableToReplace.insert(transposeOp);
+
+ bool gotRid;
+ do {
+ gotRid = false;
+ for (const auto &[transposeOp, dependentOps] : transposeInfo) {
+ // We don't care about it. Already invalidated.
+ if (!ableToReplace.count(transposeOp))
+ continue;
+
+ // Check for validity.
+ if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
+ transposeInfo)) {
+ ableToReplace.erase(transposeOp);
+ gotRid = true;
+ break;
+ }
+ }
+
+ } while (gotRid);
+
+ return ableToReplace;
+}
+
+void TosaRemoveRedundantTransposes::runOnOperation() {
+ // We only operate on regions with a single block;
+ // run --inline before this pass if needed.
+ if (!getOperation().getRegion().hasOneBlock())
+ return;
+
+ IRRewriter rewriter(&getContext());
+ // For each perms, map values to their converted values to avoid duplication.
+ DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
+ // For each perms, we keep track of which tosa::TransposeOp are eligible
+ // for replacement alongside their dependentOps.
+ DenseMap<ArrayRef<int32_t>,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>>
+ permsToTransposeInfo;
+
+ // Owns the perms storage for lifetime reasons, since the DenseMaps above
+ // only hold ArrayRefs. Each perms is a SmallVector (common case is rank <= 4),
+ // while the outer container is a std::vector since its size is unbounded.
+ std::vector<SmallVector<int32_t>> collectedPerms;
+
+ // This keeps track of the order across all eligible-for-replacement
+ // TransposeOp and their perms, a necessity for the final replacements.
+ std::stack<std::pair<tosa::TransposeOp, ArrayRef<int32_t>>>
+ totalTransposeOrder;
+
+ // Reserve the space up front. SmallVector stores small payloads inline,
+ // so ArrayRefs pointing into that storage would be invalidated if
+ // collectedPerms reallocated and moved its elements, which we must
+ // avoid.
+ size_t expectedMaxPerms = 0;
+ getOperation().walk([&](tosa::TransposeOp) { expectedMaxPerms += 1; });
+ collectedPerms.reserve(expectedMaxPerms);
+
+ getOperation().walk([&](tosa::TransposeOp transposeOp) {
+ SetVector<Operation *> dependentOps;
+ collectedPerms.emplace_back();
+ SmallVector<int32_t> &perms = collectedPerms.back();
+
+ // Dynamic shapes are OK,
+ // but the incompatible ones will be rejected later.
+ auto input = transposeOp.getInput1();
+ auto output = transposeOp.getOutput();
+
+ // However, we don't support unranked tensors.
+ if (!llvm::isa<RankedTensorType>(input.getType()) ||
+ !llvm::isa<RankedTensorType>(output.getType()))
+ return;
+
+ // No transformation when the transpose permutation is non-constant.
+ if (failed(transposeOp.getConstantPerms(perms)))
+ return;
+
+ // We let --canonicalize deal with identity transposes.
+ if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
+ return;
+
+ // Fails if the fan-in does not meet the basic invariants required to
+ // perform our conversions.
+ if (!collectFanIn(input.getDefiningOp(), dependentOps))
+ return;
+
+ // Share one valuesMap among all transposes with the same perms, since
+ // multiple downstream transposes with different perms may converge on an
+ // op, and each perms requires a different transformation of that op.
+ DenseMap<Value, Value> &valuesMap = permsToValues[perms];
+
+ // Attempt to perform the conversions and insert them into the IR
+ // without making the inserted code "live". Also fills out valuesMap.
+ // Fails if there is an intermediate op we do not support.
+ if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
+ // Some additional operations may have been inserted, but will be
+ // removed by dead code elimination.
+ return;
+
+ // This should never happen. If it does, the state is unexpected,
+ // so we fail the pass.
+ if (!valuesMap.contains(input))
+ return signalPassFailure();
+
+ // The types may be incompatible because of dynamic shapes. In that
+ // case we skip this transpose; resolve dynamic shapes before running
+ // the pass to enable the transformation.
+ if (output.getType() != valuesMap.at(input).getType())
+ return;
+
+ auto &transposeInfo = permsToTransposeInfo[perms];
+
+ // In general, we might also want to introduce "newDependentOps"
+ // if there are new usages that don't fall inside the original fan-ins
+ // (like the tosa::TransposeOp we insert for tosa::ReshapeOp),
+ // but in this case, that is specialized enough and overlaps
+ // with another direct-use tosa::TransposeOp case we need to cover anyway.
+ transposeInfo.push_back({transposeOp, dependentOps});
+
+ // This is for the final replacement across all transposes.
+ totalTransposeOrder.push({transposeOp, perms});
+ });
+
+ // Do the full fan-in analysis per perms. If we analyzed across multiple
+ // perms and their fan-ins overlapped (e.g. through a shared dependency
+ // on a Reshape), we would also get duplicate ops.
+ // Const is special-cased.
+ std::set<tosa::TransposeOp> ableToReplace;
+ for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
+ // Gives us back the replacements that would never cause us to insert
+ // duplicate operations into the IR (i.e., our goal is only to remove
+ // transposes by replacing the existing chains, not to create a "new
+ // chain" alongside them).
+ // Ideally, --canonicalize is run before this pass, since it helps this
+ // analysis by removing dead code to allow more potentially acceptable
+ // transformations.
+ auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
+ ableToReplace.insert(goodReplacementsForPerms.begin(),
+ goodReplacementsForPerms.end());
+ }
+
+ // Perform the replacements across all transposes in reverse order,
+ // since doing them in any other order could invalidate the valuesMap
+ // mappings.
+ while (!totalTransposeOrder.empty()) {
+ auto [transposeOp, perms] = totalTransposeOrder.top();
+ totalTransposeOrder.pop();
+
+ if (ableToReplace.count(transposeOp) == 0)
+ continue;
+
+ auto &valuesMap = permsToValues[perms];
+ auto input = transposeOp.getInput1();
+
+ // The purpose of this reverse iteration
+ // is to avoid valuesMap invalidation. If it happens,
+ // something is wrong.
+ if (!valuesMap.contains(input))
+ return signalPassFailure();
+
+ rewriter.replaceOp(transposeOp, valuesMap.at(input));
+ }
+
+ // A single reverse walk removes all dead code: each user appears after
+ // the ops defining its operands, so a dead user is erased before we
+ // visit those ops, which may then become trivially dead in turn.
+ getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
+ [&](Operation *op) {
+ if (isOpTriviallyDead(op))
+ rewriter.eraseOp(op);
+ });
+}
+
+} // namespace
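This standalone sketch illustrates two ideas in the pass. The first is when two transposes nullify each other. The second is the fixed-point whittling done by getGoodReplacements. The sketch is not part of the patch and uses no MLIR APIs. The names composePerms, isIdentity, Candidate and goodReplacements are hypothetical. The Candidate model is a deliberate simplification of dependenciesAreValid, and it keeps only one rule: every use outside a transpose's fan-in cone must lie inside the cone of another transpose that is still replaceable. It ignores the direct-use TransposeOp and perms checks.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical helper. For a TOSA transpose, output dimension i is input
// dimension perms[i], so transposing by `first` and then by `second` equals
// a single transpose whose perms are first[second[i]].
std::vector<int32_t> composePerms(const std::vector<int32_t> &first,
                                  const std::vector<int32_t> &second) {
  assert(first.size() == second.size());
  std::vector<int32_t> composed(second.size());
  for (std::size_t i = 0; i < second.size(); ++i)
    composed[i] = first[second[i]];
  return composed;
}

// Two transposes nullify each other when their composition is the identity.
bool isIdentity(const std::vector<int32_t> &perms) {
  for (std::size_t i = 0; i < perms.size(); ++i)
    if (perms[i] != static_cast<int32_t>(i))
      return false;
  return true;
}

// Hypothetical, simplified model of a candidate downstream transpose:
// `fanIn` holds ids of the ops in its fan-in cone, and `outsideUsers` holds
// ids of users of those ops that are not themselves in the cone.
struct Candidate {
  std::set<int> fanIn;
  std::vector<int> outsideUsers;
};

// Returns the indices of candidates that remain replaceable. Start by
// assuming all are, then drop any candidate with an outside user that is not
// covered by the fan-in of a still-replaceable candidate. Dropping one can
// uncover another, so repeat until nothing changes.
std::set<std::size_t> goodReplacements(const std::vector<Candidate> &cands) {
  std::set<std::size_t> able;
  for (std::size_t i = 0; i < cands.size(); ++i)
    able.insert(i);

  bool changed = true;
  while (changed) {
    changed = false;
    // Iterate over a copy, since `able` may shrink inside the loop.
    for (std::size_t i : std::set<std::size_t>(able)) {
      for (int user : cands[i].outsideUsers) {
        bool covered = false;
        for (std::size_t j : able)
          if (cands[j].fanIn.count(user))
            covered = true;
        if (!covered) {
          able.erase(i);
          changed = true;
          break;
        }
      }
    }
  }
  return able;
}
```

For example, composePerms({0, 2, 3, 1}, {0, 3, 1, 2}) is the identity, which matches the NCHW/NHWC pairs in the tests. Consider a candidate whose only outside use lies in another candidate's cone. If that other candidate is dropped, the first candidate is also dropped on the next iteration. This is why a single pass over the candidates is not enough.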
diff --git a/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes.mlir b/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes.mlir
new file mode 100644
index 00000000000000..9f31b787998e90
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes.mlir
@@ -0,0 +1,649 @@
+// RUN: mlir-opt --verify-diagnostics --split-input-file --verify-each --tosa-remove-redundant-transposes %s | FileCheck %s
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_single_step
+// CHECK-NEXT: %[[RESULT:.*]] = tosa.ceil %arg0
+// CHECK-NEXT: return %[[RESULT]]
+func.func @test_transpose_tracks_to_nullifying_single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %ceil = tosa.ceil %0 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %ceil, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_multi_unary_step
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %[[CLAMP]]
+// CHECK-NEXT: %[[NOT:.*]] = tosa.bitwise_not %[[ABS]]
+// CHECK-NEXT: return %[[NOT]]
+func.func @test_transpose_tracks_to_nullifying_multi_unary_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %bitwise_not, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_diverging_binary
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %arg1
+// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
+// CHECK-NEXT: return %[[ADD]]
+func.func @test_transpose_tracks_to_nullifying_diverging_binary(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_diverging_binary_with_broadcasting
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %arg1
+// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
+// CHECK-NEXT: return %[[ADD]]
+func.func @test_transpose_tracks_to_nullifying_diverging_binary_with_broadcasting(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x1x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x1x4xi32>, tensor<4xi32>) -> tensor<1x1x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x1x4x2xi32>) -> tensor<1x1x4x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x1x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying__converging_binary
+// CHECK-NEXT: %[[RESULT:.*]] = tosa.add %arg0, %arg0
+// CHECK-NEXT: return %[[RESULT]]
+func.func @test_transpose_tracks_to_nullifying__converging_binary(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %clamp = tosa.add %0, %0 : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_torch_conv2d_with_elementwise_in_between
+// CHECK: %[[CONV1:.*]] = tosa.conv2d
+// CHECK: %[[CEIL:.*]] = tosa.ceil %[[CONV1]]
+// CHECK: %[[CONV2:.*]] = tosa.conv2d %[[CEIL]]
+// CHECK: %[[FLOOR:.*]] = tosa.floor %[[CONV2]]
+// CHECK: %[[CONV3:.*]] = tosa.conv2d %[[FLOOR]]
+// CHECK: %[[RES:.*]] = tosa.transpose %[[CONV3]]
+// CHECK: return %[[RES]]
+func.func @test_torch_conv2d_with_elementwise_in_between(%arg0: tensor<3x3x10x10xf32>) -> tensor<3x3x7x7xf32> {
+ %0 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32_2> : tensor<3xf32>}> : () -> tensor<3xf32>
+ %1 = "tosa.const"() <{value = dense_resource<torch_tensor_3_3_2_2_torch.float32_2> : tensor<3x3x2x2xf32>}> : () -> tensor<3x3x2x2xf32>
+ %2 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32_1> : tensor<3xf32>}> : () -> tensor<3xf32>
+ %3 = "tosa.const"() <{value = dense_resource<torch_tensor_3_3_2_2_torch.float32_1> : tensor<3x3x2x2xf32>}> : () -> tensor<3x3x2x2xf32>
+ %4 = "tosa.const"() <{value = dense_resource<torch_tensor_3_3_2_2_torch.float32> : tensor<3x3x2x2xf32>}> : () -> tensor<3x3x2x2xf32>
+ %5 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>}> : () -> tensor<3xf32>
+ %6 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %7 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %8 = tosa.transpose %arg0, %6 : (tensor<3x3x10x10xf32>, tensor<4xi32>) -> tensor<3x10x10x3xf32>
+ %9 = tosa.transpose %4, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %10 = tosa.conv2d %8, %9, %5 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x10x10x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x9x9x3xf32>
+ %11 = tosa.transpose %10, %7 : (tensor<3x9x9x3xf32>, tensor<4xi32>) -> tensor<3x3x9x9xf32>
+ %12 = tosa.ceil %11 : (tensor<3x3x9x9xf32>) -> tensor<3x3x9x9xf32>
+ %13 = tosa.transpose %12, %6 : (tensor<3x3x9x9xf32>, tensor<4xi32>) -> tensor<3x9x9x3xf32>
+ %14 = tosa.transpose %3, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %15 = tosa.conv2d %13, %14, %2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x9x9x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x8x8x3xf32>
+ %16 = tosa.transpose %15, %7 : (tensor<3x8x8x3xf32>, tensor<4xi32>) -> tensor<3x3x8x8xf32>
+ %17 = tosa.floor %16 : (tensor<3x3x8x8xf32>) -> tensor<3x3x8x8xf32>
+ %18 = tosa.transpose %17, %6 : (tensor<3x3x8x8xf32>, tensor<4xi32>) -> tensor<3x8x8x3xf32>
+ %19 = tosa.transpose %1, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %20 = tosa.conv2d %18, %19, %0 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x8x8x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x7x7x3xf32>
+ %21 = tosa.transpose %20, %7 : (tensor<3x7x7x3xf32>, tensor<4xi32>) -> tensor<3x3x7x7xf32>
+ return %21 : tensor<3x3x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_mulop_conversion
+// CHECK-NEXT: %[[RES:.*]] = tosa.mul %arg0, %arg1
+// CHECK-NEXT: return %[[RES]]
+func.func @test_mulop_conversion(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %mul = tosa.mul %transpose0, %transpose1 {shift = 0 : i8} : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %mul, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// COM: this case is a reshape we don't convert, since we can't fold the transpose into it.
+// COM: a transform actually occurs under the hood, but it results in identical IR.
+// CHECK-LABEL: @test_basic_non_broadcasting_reshape
+// CHECK: "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: tosa.reshape %arg0 {new_shape = array<i64: 1, 3, 2>} : (tensor<2x3xi32>) -> tensor<1x3x2xi32>
+// CHECK: tosa.transpose %1, %0 : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+func.func @test_basic_non_broadcasting_reshape(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> {
+ %perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, 3, 2>} : (tensor<2x3xi32>) -> tensor<1x3x2xi32>
+ %2 = tosa.transpose %1, %perms : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+ return %2 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_broadcasting_reshape
+// CHECK: %[[RES:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, -1>} : (tensor<?xi32>) -> tensor<1x1x?xi32>
+// CHECK: return %[[RES]]
+func.func @test_dynamic_broadcasting_reshape(%arg0: tensor<?xi32>) -> tensor<1x1x?xi32> {
+ %perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, -1, 1>} : (tensor<?xi32>) -> tensor<1x?x1xi32>
+ %2 = tosa.transpose %1, %perms : (tensor<1x?x1xi32>, tensor<3xi32>) -> tensor<1x1x?xi32>
+ return %2 : tensor<1x1x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_reshape_for_broadcast
+// CHECK-DAG: %[[RESHAPE_INPUT:.*]] = "tosa.const"() <{value = dense<[1, 2, 3, 4]>
+// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %[[RESHAPE_INPUT]] {new_shape = array<i64: 4, 1, 1>}
+// CHECK-DAG: %[[ADD:.*]] = tosa.add %arg0, %[[RESHAPE]]
+// CHECK: return %[[ADD]]
+func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2xi32> {
+ %0 = "tosa.const"() {value = dense<[1,2,3,4]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %reshape = tosa.reshape %0 {new_shape = array<i64: 1, 1, 4>} : (tensor<4xi32>) -> tensor<1x1x4xi32>
+ %perms0 = "tosa.const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<4x3x2xi32>, tensor<3xi32>) -> tensor<2x3x4xi32>
+ %add = tosa.add %transpose0, %reshape : (tensor<2x3x4xi32>, tensor<1x1x4xi32>) -> tensor<2x3x4xi32>
+ %transpose1 = tosa.transpose %add, %perms0 : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x3x2xi32>
+ return %transpose1 : tensor<4x3x2xi32>
+}
+
+// -----
+
+// COM: taken directly from ResNet18 translation.
+// COM: changes: %74 as argument instead of result of conv2d
+
+// CHECK-LABEL: @test_resnet18_common_case
+// COM: note that %74 is now represented by %arg2
+// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[VAL_6:.*]] = tosa.add %arg1, %[[VAL_4]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = tosa.pow %[[VAL_6]], %[[VAL_5]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor<64xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_9:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_10:.*]] = tosa.sub %arg2, %[[VAL_9]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_12:.*]] = tosa.mul %[[VAL_10]], %[[VAL_11]] {shift = 0 : i8} : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_15:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_16:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_17:.*]] = tosa.clamp %[[VAL_16]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK: return %[[VAL_17]] : tensor<1x112x112x64xf32>
+
+func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> {
+ %59 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
+ %60 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
+ %63 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %64 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %69 = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %70 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %75 = tosa.transpose %74, %64 : (tensor<1x112x112x64xf32>, tensor<4xi32>) -> tensor<1x64x112x112xf32>
+ %76 = tosa.add %arg1, %69 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+ %77 = tosa.pow %76, %70 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+ %78 = tosa.reciprocal %77 : (tensor<64xf32>) -> tensor<64xf32>
+ %79 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %81 = tosa.reshape %78 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %82 = tosa.mul %80, %81 {shift = 0 : i8} : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %83 = tosa.reshape %60 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %84 = tosa.mul %82, %83 {shift = 0 : i8} : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %85 = tosa.reshape %59 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %87 = tosa.clamp %86 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>
+ %88 = tosa.transpose %87, %63 : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
+ return %88 : tensor<1x112x112x64xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @test_back_to_back_nullifiers
+// CHECK: %[[PERMS:.*]] = "tosa.const"
+// CHECK: %[[RES:.*]] = tosa.transpose %arg0, %[[PERMS]]
+// CHECK: return %[[RES]]
+func.func @test_back_to_back_nullifiers(%arg0: tensor<2x3xi32>) -> tensor<3x2xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %2 = tosa.transpose %1, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ return %2 : tensor<3x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_back_to_back_nullifiers_different_transposes
+// CHECK: %[[PERMS:.*]] = "tosa.const"
+// CHECK: %[[RES:.*]] = tosa.transpose %arg0, %[[PERMS]]
+// CHECK: return %[[RES]]
+func.func @test_back_to_back_nullifiers_different_transposes(%arg0: tensor<2x3x4x5xi32>) -> tensor<2x4x5x3xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<2x4x5x3xi32>
+ %1 = tosa.transpose %0, %perms1 : (tensor<2x4x5x3xi32>, tensor<4xi32>) -> tensor<2x3x4x5xi32>
+ %2 = tosa.transpose %1, %perms0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<2x4x5x3xi32>
+ return %2 : tensor<2x4x5x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_transform_if_outside_fan_in_cone
+// CHECK: tosa.const
+// CHECK: %[[CLAMP_IN:.*]] = tosa.transpose
+// CHECK: %[[RES2:.*]] = tosa.clamp %[[CLAMP_IN]]
+// CHECK: tosa.const
+// CHECK: %[[RES1:.*]] = tosa.transpose
+// CHECK: return %[[RES1]], %[[RES2]]
+func.func @test_no_transform_if_outside_fan_in_cone(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %clamp : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_two_different_downstream_converge_to_reshape_same_perms
+// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %arg0
+// CHECK-DAG: %[[CLAMP:.*]] = tosa.clamp %[[RESHAPE]]
+// CHECK: return %[[RESHAPE]], %[[CLAMP]]
+func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: tensor<64xf32>) -> (tensor<1x1x64xf32>, tensor<1x1x64xf32>) {
+ %0 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+ %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1>} : (tensor<64xf32>) -> tensor<1x64x1xf32>
+ %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32>
+ %3 = tosa.transpose %1, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
+ %4 = tosa.transpose %2, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
+ return %3, %4 : tensor<1x1x64xf32>, tensor<1x1x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_two_different_downstream_converge_to_reshape_different_perms
+// CHECK-DAG: tosa.const
+// CHECK-DAG: tosa.const
+// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape
+// CHECK-DAG: %[[CLAMP:.*]] = tosa.clamp %[[RESHAPE]]
+// CHECK-DAG: %[[RET1:.*]] = tosa.transpose
+// CHECK-DAG: %[[RET2:.*]] = tosa.transpose
+// CHECK-DAG: return %[[RET1]], %[[RET2]]
+func.func @test_two_different_downstream_converge_to_reshape_different_perms(%arg0: tensor<64xf32>) -> (tensor<1x1x64xf32>, tensor<64x1x1xf32>) {
+ %0 = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+ %1 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+ %2 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1>} : (tensor<64xf32>) -> tensor<1x64x1xf32>
+ %3 = tosa.clamp %2 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32>
+ %4 = tosa.transpose %2, %1 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32>
+ %5 = tosa.transpose %3, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<64x1x1xf32>
+ return %4, %5 : tensor<1x1x64xf32>, tensor<64x1x1xf32>
+}
+
+// -----
+
+// COM: no transform
+// CHECK-LABEL: @test_outside_perms_usage_of_fan_in
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.clamp
+// CHECK: %[[RES1:.*]] = tosa.transpose
+// CHECK: %[[RES2:.*]] = tosa.add
+// CHECK: return %[[RES1]], %[[RES2]]
+func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> (tensor<2x3xf32>, tensor<3x2xf32>) { %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x2xf32>) -> tensor<3x2xf32>
+ %3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ %4 = tosa.add %arg1, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
+ return %3, %4: tensor<2x3xf32>, tensor<3x2xf32>
+}
+
+// -----
+
+// COM: this use case is important for ResNet. We want to allow these, but disallow them if the use falls into an illegal area (outside the fan-ins that get converted),
+// COM: since then we would get duplicate ops.
+// CHECK-LABEL: @test_use_present_in_another_valid_perms_fan_in
+// CHECK-DAG: %[[NEW_CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-DAG: %[[NEW_ADD:.*]] = tosa.add %arg1, %[[NEW_CLAMP]]
+// CHECK: return %[[NEW_CLAMP]], %[[NEW_ADD]]
+func.func @test_use_present_in_another_valid_perms_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
+ %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x2xf32>) -> tensor<3x2xf32>
+ %3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ %4 = tosa.transpose %arg1, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %5 = tosa.add %4, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
+ %6 = tosa.transpose %5, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ return %3, %6: tensor<2x3xf32>, tensor<2x3xf32>
+}
+
+// -----
+
+// COM: no transform, since we would get duplicates
+// CHECK-LABEL: @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: %[[CEIL:.*]] = tosa.ceil
+// CHECK: %[[ADD:.*]] = tosa.add %[[CEIL]]
+// CHECK: %[[RES1:.*]] = tosa.transpose %[[CEIL]]
+// CHECK: %[[RES2:.*]] = tosa.transpose %[[ADD]]
+// CHECK: return %[[RES1]], %[[RES2]]
+func.func @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents(%arg0: tensor<2x3xi32>, %arg1: tensor<3x2xi32>) -> (tensor<2x3xi32>, tensor<2x3xi32>) {
+ %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ %2 = tosa.ceil %1 : (tensor<3x2xi32>) -> tensor<3x2xi32>
+ %3 = tosa.add %2, %arg1 : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
+ %4 = tosa.transpose %2, %0 : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %5 = tosa.transpose %3, %0 : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ return %4, %5 : tensor<2x3xi32>, tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_direct_use_in_other_transpose_with_same_perms
+// CHECK-NEXT: %[[RES:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: return %[[RES]], %[[RES]]
+func.func @test_direct_use_in_other_transpose_with_same_perms(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_const_transpose
+// CHECK: %[[NEW:.*]] = "tosa.const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+// CHECK-NOT: tosa.transpose
+// CHECK: return %[[NEW]]
+func.func @test_const_transpose() -> tensor<2x3xi32> {
+ %0 = "tosa.const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ return %1 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_const_single_step
+// CHECK: %[[NEW_CONST:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %[[NEW_CONST]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK-NOT: tosa.transpose
+// CHECK: return %[[NEW_CLAMP]]
+func.func @test_transpose_tracks_to_const_single_step() -> tensor<1x2x3x4xi32> {
+ %0 = "tosa.const"() {value = dense<0> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_unary_path_to_const
+// CHECK: %[[NEW_CONST:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %[[NEW_CONST]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_ABS:.*]] = tosa.abs %[[NEW_CLAMP]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_NOT:.*]] = tosa.bitwise_not %[[NEW_ABS]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: return %[[NEW_NOT]]
+func.func @test_static_unary_path_to_const() -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = "tosa.const"() {value = dense<1> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %bitwise_not, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_diverges_to_non_splat_const_and_nullifying
+// CHECK: %[[NEW_CONST:.*]] = "tosa.const"()
+// CHECK-SAME{LITERAL}: dense<[[[[1, 3, 5, 7], [9, 11, 13, 15], [17, 19, 21, 23]], [[2, 4, 6, 8], [10, 12, 14, 16], [18, 20, 22, 24]]]]>
+// CHECK: tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %arg0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_ABS:.*]] = tosa.abs %[[NEW_CONST]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_ADD:.*]] = tosa.add %[[NEW_ABS]], %[[NEW_CLAMP]] : (tensor<1x2x3x4xi32>, tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: return %[[NEW_ADD]]
+func.func @test_static_diverges_to_non_splat_const_and_nullifying(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %const = "tosa.const"() {value = dense<[[[[1, 2], [3, 4], [5, 6], [7, 8]],
+ [[9, 10], [11, 12], [13, 14], [15, 16]],
+ [[17, 18], [19, 20], [21, 22], [23, 24]]]]> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %const : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %abs, %clamp : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms2 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_multi_downstream_both_nullify
+// CHECK-NEXT: %[[RES:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: return %[[RES]], %[[RES]]
+func.func @test_multi_downstream_both_nullify(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// COM: we intentionally don't perform this transformation, since it would introduce duplicates
+// CHECK-LABEL: @test_multi_downstream_one_nullifies_upstream_other_does_not
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.clamp
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.transpose
+func.func @test_multi_downstream_one_nullifies_upstream_other_does_not(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unknown_dim_inner_replacement_matches
+// CHECK-NEXT: return %arg0
+func.func @test_unknown_dim_inner_replacement_matches(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<?x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<?x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ return %1 : tensor<3x2xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @test_unknown_dim_outer_replacement_matches
+// CHECK-NEXT: return %arg0
+func.func @test_unknown_dim_outer_replacement_matches(%arg0: tensor<3x?xi32>) -> tensor<3x?xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x?xi32>
+ return %1 : tensor<3x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_diverging_binary_unknown_dim_replacements_match
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %arg1
+// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
+// CHECK-NEXT: return %[[ADD]]
+func.func @test_transpose_tracks_to_nullifying_diverging_binary_unknown_dim_replacements_match(%arg0: tensor<1x?x3x4xi32>, %arg1: tensor<1x2x?x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x3x4x?xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x?x4xi32>, tensor<4xi32>) -> tensor<1x?x?x2xi32>
+ %clamp = tosa.clamp %transpose0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<?x3x4x?xi32>) -> tensor<?x3x4x?xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x?x?x2xi32>) -> tensor<1x?x?x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<?x3x4x?xi32>, tensor<1x?x?x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// COM: we cannot do anything to the transpose in this case, since its perms are not a constant.
+// CHECK-LABEL: @test_unimplemented_non_const_perms
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_non_const_perms(%perms: tensor<2xi32>) -> tensor<?x?xi32> {
+ %0 = "tosa.const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<?x?xi32>
+ return %1 : tensor<?x?xi32>
+}
+
+// -----
+
+// COM: since this tracks back to a non-nullifying transpose, we can't remove the transposes entirely.
+// COM: future versions of the pass may fold these into a single transpose.
+// CHECK-LABEL: @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_single_step
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x4x3xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x4x3x2xi32>
+ %clamp = tosa.clamp %0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<1x4x3x2xi32>) -> tensor<1x4x3x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<1x4x3x2xi32>, tensor<4xi32>) -> tensor<1x2x4x3xi32>
+ return %1 : tensor<1x2x4x3xi32>
+}
+
+// -----
+
+// COM: we don't handle this case, since it would require resolving the input's unknown dimension.
+// CHECK-LABEL: @test_unimplemented_unknown_dim_input_nullifying_pair
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_unknown_dim_input_nullifying_pair(%arg0: tensor<3x?xi32>) -> tensor<3x2xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ return %1 : tensor<3x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unimplemented_unknown_dim_replacement_does_not_match
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_unknown_dim_replacement_does_not_match(%arg0: tensor<3x?xi32>) -> tensor<?x?xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<?x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<?x3xi32>, tensor<2xi32>) -> tensor<?x?xi32>
+ return %1 : tensor<?x?xi32>
+}
+
+// -----
+
+// COM: this could be converted if --tosa-infer-shapes were run beforehand
+// CHECK-LABEL: @test_unimplemented_unranked_tensors_present
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_unranked_tensors_present(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+ %perms = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unimplemented_unranked_everything
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_unranked_everything(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unimplemented_static_diverges_to_one_nullifying_one_non_nullifying
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: tosa.abs
+// CHECK-NEXT: tosa.add
+// CHECK-NEXT: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unimplemented_static_diverges_to_one_nullifying_one_non_nullifying(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x4x3xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms1 : (tensor<1x2x4x3xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms2 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
More information about the Mlir-commits
mailing list