[Mlir-commits] [mlir] [MLIR][TOSA] Add --tosa-remove-redundant-transposes pass (PR #108260)
Arteen Abrishami
llvmlistbot at llvm.org
Wed Sep 11 20:08:09 PDT 2024
https://github.com/arteen1000 updated https://github.com/llvm/llvm-project/pull/108260
From 7f8ec4b983812531e959d85c2104b2ed2ac73d22 Mon Sep 17 00:00:00 2001
From: Arteen Abrishami <arteen.abrishami at arm.com>
Date: Thu, 5 Sep 2024 22:03:19 +0000
Subject: [PATCH] [MLIR][TOSA] Add --tosa-remove-redundant-transposes pass
----------
Motivation:
----------
Some legalization pathways introduce redundant tosa.TRANSPOSE
operations that result in avoidable data movement. For example,
PyTorch -> TOSA contains a lot of unnecessary transposes due
to conversions between NCHW and NHWC.
We wish to remove all the ones that we can, since in general
it is possible to remove upwards of 90% of these transposes
in a provable manner.
------------
Changes Made:
------------
- Add the --tosa-remove-redundant-transposes pass
- Add TosaElementwiseOperator trait.
-------------------
High-Level Overview:
-------------------
The pass begins at a downstream transpose with some perms tensor.
It traverses the dependencies upward, accepting only TosaElementwise operators.
Dependencies must terminate in nullifying transposes (when composed, they form the identity),
reshapes, or consts.
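
As an illustration of what "nullifying" means here, the sketch below checks
whether two permutations compose to the identity. It follows the same test as
the areNullifyingTransposes helper in this patch, but uses std::vector in place
of ArrayRef so that it compiles standalone; the name areNullifying is
hypothetical, and both inputs are assumed to be valid permutations.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical standalone helper (not part of the patch).
// Returns true when perms2[perms1[i]] == i for every axis i, which is the
// condition the patch uses to decide that two transposes cancel.
// Assumes both vectors are valid permutations of 0 .. n-1.
bool areNullifying(const std::vector<int32_t> &perms1,
                   const std::vector<int32_t> &perms2) {
  if (perms1.size() != perms2.size())
    return false;
  for (std::size_t i = 0; i < perms1.size(); ++i)
    if (perms2[perms1[i]] != static_cast<int32_t>(i))
      return false;
  return true;
}
```

For instance, the NCHW -> NHWC permutation [0, 2, 3, 1] and the NHWC -> NCHW
permutation [0, 3, 1, 2] satisfy this check.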
Conceptually, we then "bubble up" the downstream transpose until
we hit the sources. For constants, we generate a new constant, composed
with the downstream transpose. For nullifying transposes, we "cancel"
them. For reshapes, we generally cannot "bubble" through them, so we
insert the downstream transpose there.
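
Composing a constant with the downstream transpose amounts to transposing its
flat row-major data. The following is a minimal sketch of a stride-based way to
do that, loosely modelled on the approach described for transposeDenseAttribute
in this patch. The function name transposeFlat, the int element type, and the
iterative index walk are assumptions made for illustration; it also assumes
that output dimension i takes its extent from input dimension perms[i].

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the patch): transposes a row-major flat
// array. Assumes input.size() equals the product of the extents in `shape`.
std::vector<int> transposeFlat(const std::vector<int> &input,
                               const std::vector<int64_t> &shape,
                               const std::vector<int32_t> &perms) {
  const std::size_t rank = shape.size();
  assert(rank > 0 && perms.size() == rank);

  // Row-major strides of the input.
  std::vector<int64_t> inStrides(rank, 1);
  for (std::size_t i = rank - 1; i > 0; --i)
    inStrides[i - 1] = inStrides[i] * shape[i];

  // Reorder extents and strides so that both are indexed by output dimension.
  std::vector<int64_t> outShape(rank), permStrides(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    outShape[i] = shape[perms[i]];
    permStrides[i] = inStrides[perms[i]];
  }

  // Visit output elements in row-major order, tracking a multi-dimensional
  // index and mapping it to the matching flat input offset.
  std::vector<int> output;
  output.reserve(input.size());
  std::vector<int64_t> idx(rank, 0);
  for (std::size_t n = 0; n < input.size(); ++n) {
    int64_t src = 0;
    for (std::size_t i = 0; i < rank; ++i)
      src += idx[i] * permStrides[i];
    output.push_back(input[src]);

    // Advance the index like an odometer, innermost dimension first.
    for (std::size_t i = rank; i-- > 0;) {
      if (++idx[i] < outShape[i])
        break;
      idx[i] = 0;
    }
  }
  return output;
}
```

With a 2x3 input holding 0..5 and perms [1, 0], this yields the 3x2 data
0, 3, 1, 4, 2, 5.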
We then ensure that we do not cause any duplication by "converting"
this chain we bubbled-up into its transposed form. We do this by analyzing
the dependency fan-ins across all transposes with the same perms tensor
in order to ensure that they do not have uses outside this group, which
would cause the old code section to remain "live", and not removed by
canonicalization.
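
The core idea of the fan-in check can be sketched on a toy graph. This is a
deliberate simplification and not the pass's actual data structures: ops are
plain integer ids, the users map and the function name coneHasNoOutsideUses are
hypothetical, and the special handling of constants is left out. A cone is
treated as safe to convert when none of its ops is used outside the cone or
outside a set of allowed sinks (for example, downstream transposes with the
same perms).

```cpp
#include <map>
#include <set>
#include <vector>

using OpId = int; // Hypothetical stand-in for an operation handle.

// Returns true if every user of every op in `cone` is either inside `cone`
// or listed in `allowedSinks`. Ops with no entry in `users` have no users.
bool coneHasNoOutsideUses(const std::set<OpId> &cone,
                          const std::map<OpId, std::vector<OpId>> &users,
                          const std::set<OpId> &allowedSinks) {
  for (OpId op : cone) {
    auto it = users.find(op);
    if (it == users.end())
      continue;
    for (OpId user : it->second)
      if (!cone.count(user) && !allowedSinks.count(user))
        return false; // The old chain would stay live through this user.
  }
  return true;
}
```

If the check fails, replacing the downstream transpose would leave both the
old and the converted chain live, which is the duplication the pass avoids.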
--------------
Impact of Pass:
--------------
For the ResNet18 network, we are able to reduce it to 5 transposes, from
56 -- with the patching of the torch dense_resource artifacts with dense
attributes. Otherwise, without that patch, we reduce to 23, since we cannot
fold those artifacts.
In the second case (56 -> 23), instruction count is reduced by exactly 33.
There are 3 transposes that would be removed if we omitted the fan-in analysis;
however, with fan-in analysis, we end up with ~15 fewer operations, due
to the lack of duplication.
For ResNet50, the results are essentially identical.
For MobilenetV3, we reduce the number of transposes from 82 to 38 without
taking care of upstream constants. After also taking care of constants, we
reduce it to 20 transposes. The remaining ones have uses elsewhere, outside
of the fan-in cones. The pass alone (after --canonicalize is run on the
initial network), is responsible for the removal of 48 of the transposes.
Due to cases where a constant is used purely in its NCHW form without a
transpose to NHWC and also separately used in a place where the downstream
converts to NHWC, we do end up with 7 additional constants; however, due to
their small size, this has minimal memory footprint.
-----------
Future Work:
-----------
(1)
Evaluate tradeoffs with the duplication of ConstOp, especially
across many downstream transposes with different perms, which can result
in the same ConstOp being duplicated (but transposed) multiple times.
Observe tradeoffs between a lower memory footprint and potentially
converting many fan-ins of downstream transposes with the same perms,
which, if not converted, may affect the ability of other inter-dependent
fan-ins to convert.
(2)
Restrict the propagation of transposes up their fan-in cone if one
of the sources is a ReshapeOp for which the inserted TransposeOp would
not be a TransposeOp that lends itself to the TransposeIsReshape Canonicalization,
which permits them to be folded to a single ReshapeOp.
Observe how this restriction may be detrimental to the conversion
of other downstream transposes due to the
fan-in cone analysis. Additionally, consider cases where there
may be multiple upstream transposes that could be removed as a
result of this -- and trade that off with how many you would
effectively insert if the ReshapeOp/TransposeOp can't be folded
to a single ReshapeOp.
(3)
Make the pass more general, beyond just allowing upstream transposes
to be nullifying. For example,
transpose1 -> ... -> transpose2
where transpose2(transpose1) does not cancel to identity.
This can be done by propagating the downstream transpose up
and inserting after transpose1, just like how it is done for
reshape. However, in the case of chains like
transpose1 -> ... -> transpose2 -> ... -> transpose3
this could require running the current runOnOperation() function
until we converge. This can be done by stopping when all transposes
that we can successfully collect the fan-ins of have the owner
of their first operand being either another TransposeOp or a
ReshapeOp, since those are what we propagate to and where we leave
behind / insert another TransposeOp. Otherwise, we could potentially
have infinite looping.
This additionally has the implication that we would not replace any
transposes and instead we could have canonicalization handle that.
(4)
Add support for more instructions (for example, those that reduce
along an axis) to be one of the intervening operations in the
fan-in cones (other than those with TosaElementwiseOperator trait).
(5)
Support bubbling transposes up to the input parameter. This may not
need extensive fan-in analysis, as there is no operation cost associated
with a parameter that is also used elsewhere.
Signed-off-by: Arteen Abrishami <arteen.abrishami at arm.com>
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 10 +
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 6 +
.../mlir/Dialect/Tosa/Transforms/Passes.h | 2 +
.../mlir/Dialect/Tosa/Transforms/Passes.td | 16 +
.../Dialect/Tosa/Transforms/CMakeLists.txt | 1 +
.../TosaRemoveRedundantTransposes.cpp | 761 ++++++++++++++++++
...e-redundant-transposes-dynamic-shapes.mlir | 55 ++
...-remove-redundant-transposes-pipeline.mlir | 48 ++
...ve-redundant-transposes-static-shapes.mlir | 552 +++++++++++++
...ve-redundant-transposes-unimplemented.mlir | 120 +++
10 files changed, 1571 insertions(+)
create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp
create mode 100644 mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-dynamic-shapes.mlir
create mode 100644 mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-pipeline.mlir
create mode 100644 mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-static-shapes.mlir
create mode 100644 mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-unimplemented.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..64bacd0e432fe5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -206,6 +206,15 @@ def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
input, paddings, pad_value);
}]>;
+//===----------------------------------------------------------------------===//
+// TOSA Operator Trait.
+//===----------------------------------------------------------------------===//
+
+// Permits broadcasting. Elementwise trait is too strict.
+def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
+ let cppNamespace = "mlir::OpTrait::tosa";
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Class.
//===----------------------------------------------------------------------===//
@@ -219,6 +228,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape,
+ TosaElementwiseOperator,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 7ed89bff474a2e..8122752a9f3e1a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -84,6 +84,12 @@ class MulOperandsAndResultElementType
}
};
+/// This class indicates that an op is tosa-elementwise (permits broadcasting,
+/// unlike Elementwise trait)
+template <typename ConcreteType>
+class TosaElementwiseOperator
+ : public TraitBase<ConcreteType, TosaElementwiseOperator> {};
+
} // namespace tosa
} // namespace OpTrait
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 1f9522b51a4cf5..c0913171f9b17f 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
#include "mlir/Pass/Pass.h"
@@ -48,6 +49,7 @@ std::unique_ptr<Pass> createTosaInferShapesPass();
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
std::unique_ptr<Pass> createTosaOptionalDecompositions();
+std::unique_ptr<Pass> createTosaRemoveRedundantTransposes();
struct ValidationOptions {
/// Validate if operations match for the given profile.
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index a0f670de20150f..66d046c7040b4f 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -126,4 +126,20 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
];
}
+def TosaRemoveRedundantTransposes : Pass<"tosa-remove-redundant-transposes", "func::FuncOp"> {
+ let summary = "Remove redundant transposes";
+ let description = [{
+ Pass that identifies and removes redundant tosa.TRANSPOSE operations.
+ It does so by traversing dependencies of tosa.TRANSPOSE operations until they terminate in either
+ tosa.RESHAPE, a nullifying tosa.TRANSPOSE, or a tosa.CONST. It then propagates the downstream
+ transform upward through the intervening operators if it is able and replaces the downstream tosa.TRANSPOSE.
+ Results are generally better when run after canonicalization and resolution of dynamic shapes.
+ Canonicalization is required for dead code elimination after the pass is run.
+ This pass has an important use-case in cleaning up the results of frameworks that introduce a lot
+ of data-layout transformations when legalizing to TOSA, a common one being transformations between NHWC and NCHW
+ layouts.
+ }];
+ let constructor = "tosa::createTosaRemoveRedundantTransposes()";
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index c78a74b874aff1..624038b9b38981 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaLayerwiseConstantFoldPass.cpp
TosaMakeBroadcastable.cpp
TosaOptionalDecompositions.cpp
+ TosaRemoveRedundantTransposes.cpp
TosaTypeConverters.cpp
TosaValidation.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp
new file mode 100644
index 00000000000000..06d7754e79a023
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp
@@ -0,0 +1,761 @@
+//===- TosaRemoveRedundantTransposes.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ----------
+// Motivation:
+// ----------
+
+// Some legalization pathways introduce redundant tosa.TRANSPOSE
+// operations that result in avoidable data movement. For example,
+// PyTorch -> TOSA contains a lot of unnecessary transposes due
+// to conversions between NCHW and NHWC.
+
+// We wish to remove all the ones that we can, since in general
+// it is possible to remove upwards of 90% of these transposes
+// in a provable manner.
+//
+// -------------------
+// High-Level Overview:
+// -------------------
+
+// The pass begins at a downstream transpose with some perms tensor.
+// It traverses the dependencies upward, accepting only TosaElementwise
+// operators. Dependencies must terminate in nullifying transposes (when
+// composed, they form the identity), reshapes, or consts.
+
+// Conceptually, we then "bubble up" the downstream transpose until
+// we hit the sources. For constants, we generate a new constant, composed
+// with the downstream transpose. For nullifying transposes, we "cancel"
+// them. For reshapes, we generally cannot "bubble" through them, so we
+// insert the downstream transpose there.
+
+// We then ensure that we do not cause any duplication by replacing usages
+// of the downstream transpose with the converted value of the operand
+// that feeds into it (after this bubble-up process). We do this by analyzing
+// the dependency fan-ins across all transposes with the same perms tensor
+// in order to ensure that they do not have uses outside this group, which
+// would cause the old code section to remain "live", and not removed by
+// canonicalization.
+
+// --------------
+// Impact of Pass:
+// --------------
+
+// For the ResNet18 network, we are able to reduce it to 5 transposes, from
+// 56 -- with the patching of the torch dense_resource artifacts with dense
+// attributes. Otherwise, without that patch, we reduce to 23, since we cannot
+// fold those artifacts.
+
+// In the second case (56 -> 23), instruction count is reduced by exactly 33.
+// There are 3 transposes that would be removed if we omitted the fan-in
+// analysis; however, with fan-in analysis, we end up with ~15 fewer
+// operations, due to the lack of duplication.
+
+// For ResNet50, the results are essentially identical.
+
+// For MobilenetV3, we reduce the number of transposes from 82 to 38 without
+// taking care of upstream constants. After also taking care of constants, we
+// reduce it to 20 transposes. The remaining ones have uses elsewhere, outside
+// of the fan-in cones. The pass alone (after --canonicalize is run on the
+// initial network), is responsible for the removal of 48 of the transposes.
+
+// Due to cases where a constant is used purely in its NCHW form without a
+// transpose to NHWC and also separately used in a place where the downstream
+// converts to NHWC, we do end up with 7 additional constants; however, due to
+// their small size, this has minimal memory footprint.
+
+// -----------
+// Future Work:
+// -----------
+
+// (1)
+
+// Evaluate tradeoffs with the duplication of ConstOp, especially
+// across many downstream transposes with different perms, which can result
+// in the same ConstOp being duplicated (but transposed) multiple times.
+
+// Observe tradeoffs between a lower memory footprint and potentially
+// converting many fan-ins of downstream transposes with the same perms,
+// which, if not converted, may affect the ability of other inter-dependent
+// fan-ins to convert.
+
+// (2)
+
+// Restrict the propagation of transposes up their fan-in cone if one
+// of the sources is a ReshapeOp for which the inserted TransposeOp would
+// not be a TransposeOp that lends itself to the TransposeIsReshape
+// Canonicalization, which permits them to be folded to a single ReshapeOp.
+
+// Observe how this restriction may be detrimental to the conversion
+// of other downstream transposes due to the
+// fan-in cone analysis. Additionally, consider cases where there
+// may be multiple upstream transposes that could be removed as a
+// result of this -- and trade that off with how many you would
+// effectively insert if the ReshapeOp/TransposeOp can't be folded
+// to a single ReshapeOp.
+
+// (3)
+
+// Make the pass more general, beyond just allowing upstream transposes
+// to be nullifying. For example,
+
+// transpose1 -> ... -> transpose2
+
+// where transpose2(transpose1) does not cancel to identity.
+
+// This can be done by propagating the downstream transpose up
+// and inserting after transpose1, just like how it is done for
+// reshape. However, in the case of chains like
+
+// transpose1 -> ... -> transpose2 -> ... -> transpose3
+
+// this could require running the current runOnOperation() function
+// until we converge. This can be done by stopping when all transposes
+// that we can successfully collect the fan-ins of have the owner
+// of their first operand being either another TransposeOp or a
+// ReshapeOp, since those are what we propagate to and where we leave
+// behind / insert another TransposeOp. Otherwise, we could potentially
+// have infinite looping.
+
+// This additionally has the implication that we would not replace any
+// transposes and instead we could have canonicalization handle that.
+
+// (4)
+
+// Add support for more instructions (for example, those that reduce
+// along an axis) to be one of the intervening operations in the
+// fan-in cones (other than those with TosaElementwiseOperator trait).
+
+// (5)
+
+// Support bubbling transposes up to the input parameter. This may not
+// need extensive fan-in analysis, as there is no operation cost associated
+// with a parameter that is also used elsewhere.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <memory>
+#include <set>
+#include <stack>
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAREMOVEREDUNDANTTRANSPOSES
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TOSA Remove Redundant Transposes Pass.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct TosaRemoveRedundantTransposes final
+ : public tosa::impl::TosaRemoveRedundantTransposesBase<
+ TosaRemoveRedundantTransposes> {
+ void runOnOperation() override;
+
+private:
+ // This will collect all the data dependencies for the given Operation
+ // up to and including ConstOp, ReshapeOp, and TransposeOp.
+ bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
+ bool convertDependentOps(SetVector<Operation *> &dependentOps,
+ DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter,
+ ArrayRef<int32_t> downstreamPerms);
+
+ // Checks if the two permutations, when applied consecutively, result
+ // in the identity.
+ bool areNullifyingTransposes(ArrayRef<int32_t> perms1,
+ ArrayRef<int32_t> perms2);
+
+ // This is meant to apply to operations with the TosaElementwiseOperator
+ // trait.
+ std::optional<Value>
+ buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // Updates valuesMap when another TransposeOp is a dependency of the
+ // downstream one: applies to `%0 = tosa.transpose %arg0` when tracking
+ // back from `%1 = tosa.transpose %0`.
+ std::optional<Value>
+ buildMappedToValue(tosa::TransposeOp transposeOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // Inserts the downstream TransposeOp after the ReshapeOp, since we generally
+ // cannot propagate through it.
+ std::optional<Value>
+ buildMappedToValue(tosa::ReshapeOp reshapeOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // We may have something like:
+ // %0 = tosa.const
+ // %1 = tosa.transpose
+ // %2 = tosa.add %0, %1
+ // %3 = tosa.transpose %2
+ // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
+ // in MobilenetV3.
+ std::optional<Value>
+ buildMappedToValue(tosa::ConstOp constOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // Checks which TransposeOp we should "replace", turning their converted
+ // chains of ops, through which they were propagated, "live", and the old code
+ // "dead." Attempts to avoid doing so when doing so would result in the old
+ // code staying "live," resulting in duplication. Relies on --canonicalize to
+ // remove the dead code that results from performing said replacement.
+ std::set<tosa::TransposeOp> getGoodReplacements(
+ ArrayRef<int32_t> perms,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Helper function for getGoodReplacements to check if some TransposeOp's
+ // dependencies are OK.
+ bool dependenciesAreValid(
+ ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+ std::set<tosa::TransposeOp> &validTransposes,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Applies perms to the DenseElementsAttr.
+ // If it returns std::nullopt, it also triggers pass failure, since verifier
+ // guarantees from TOSA are not in place (and otherwise, if used elsewhere
+ // it should fail).
+ // This is a basic API and may benefit from refactor into the core MLIR APIs.
+ std::optional<DenseElementsAttr>
+ transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
+};
+
+std::optional<DenseElementsAttr>
+TosaRemoveRedundantTransposes::transposeDenseAttribute(
+ DenseElementsAttr input, ArrayRef<int32_t> perms) {
+ RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
+ RankedTensorType newType = RankedTensorType::get(
+ tosa::applyTOSAPermutation(oldType.getShape(), perms),
+ oldType.getElementType());
+ size_t rank = oldType.getRank();
+
+ if (input.isSplat())
+ return input.reshape(newType);
+ // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
+ // 0.
+ // If not in place, something is very wrong.
+ if (rank <= 0 || oldType.getNumElements() <= 0 || perms.size() != rank) {
+ signalPassFailure();
+ return std::nullopt;
+ }
+
+ // The algorithm is approximately as follows:
+ // input: perms, input flat array, input tensor type
+ // (1/2) determine the row-major strides of the input and output.
+ // (3) reorder the input strides to match the output's index order.
+ // (4) process dimension by dimension.
+ // example: perms 2, 0, 1; input 2x3x4; output 4x2x3
+ // for i ... 4, j ... 2, k ... 3: output[i][j][k] = input[j][k][i]
+ // output[6i + 3j + k] = input[12j + 4k + i]
+ // and we adjust the input strides to be as input[i + 12j + 4k] so that
+ // we may process layer-by-layer.
+
+ // Step 1/2: Strides for input. We ignore output since row-major and can just
+ // push_back.
+
+ SmallVector<int64_t> originalInputStrides(rank);
+ originalInputStrides[rank - 1] = 1;
+ // Index with a signed int64_t so that `i >= 0` can fail (rank is unsigned).
+ for (int64_t i = rank - 2; i >= 0; i--)
+ originalInputStrides[i] =
+ originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
+
+ // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
+ // output which is done in row-major order.
+
+ SmallVector<int64_t> newInputStrides;
+ newInputStrides.reserve(rank);
+ for (int32_t v : perms)
+ newInputStrides.push_back(originalInputStrides[v]);
+
+ // Step 4: Write out the transposed "flat array" dimension by dimension.
+
+ auto inputArray = input.getValues<Attribute>();
+ SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
+ for (size_t i = 0; i < rank; i++)
+ boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
+
+ SmallVector<Attribute> resultArray;
+ resultArray.reserve(inputArray.size());
+
+ std::function<void(int64_t,
+ SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
+ processTransposeDim = [&](auto accumulatedIndex, auto it) {
+ if (it == boundsAndStrides.end()) {
+ resultArray.push_back(inputArray[accumulatedIndex]);
+ return;
+ }
+
+ for (int64_t i = 0; i < it->first; i++) {
+ int64_t j = accumulatedIndex + i * it->second;
+ processTransposeDim(j, it + 1);
+ }
+ };
+
+ processTransposeDim(0, boundsAndStrides.begin());
+
+ return DenseElementsAttr::get(newType, resultArray);
+}
+
+// The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
+// as the sources of the data dependencies, and TosaElementwiseOperator
+// after that, if the function returns true.
+bool TosaRemoveRedundantTransposes::collectFanIn(
+ Operation *op, SetVector<Operation *> &collected) {
+ // Can occur if defined through the parameter to a func.func.
+ if (!op)
+ return false;
+
+ if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
+ return false;
+
+ // Prevent extra work if already seen.
+ if (collected.contains(op))
+ return true;
+
+ // Reject it here so that later stages do not have to handle this case.
+ if (op->getNumResults() != 1 ||
+ !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
+ return false;
+
+ // We don't wish to traverse up a ReshapeOp,
+ // since generally we can't propagate a TransposeOp through it.
+ // TransposeOp, ReshapeOp, ConstOp will have no in-edges in the data
+ // dependency graph we construct for the downstream TransposeOp.
+ if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
+ !llvm::isa<tosa::ConstOp>(op)) {
+
+ if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+ return false;
+
+ for (Value operand : op->getOperands()) {
+
+ if (!collectFanIn(operand.getDefiningOp(), collected))
+ return false;
+ }
+ }
+
+ // Insert in topological order.
+ collected.insert(op);
+
+ return true;
+}
+
+// Assumes that, as checked by the TransposeOp verifier, the perms arrays
+// are permutations of 0 to perms.size() - 1.
+bool TosaRemoveRedundantTransposes::areNullifyingTransposes(
+ ArrayRef<int32_t> perms1, ArrayRef<int32_t> perms2) {
+ if (perms1.size() != perms2.size())
+ return false;
+ for (int32_t i = 0; i < static_cast<int32_t>(perms1.size()); i++)
+ if (perms2[perms1[i]] != i)
+ return false;
+ return true;
+}
+
+// Primarily overload for those with TosaElementwiseOperator trait.
+// The other ones handle the case of the operations that occur at the
+// roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+ Operation *op, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+ if (op->getNumResults() != 1 ||
+ !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+ return std::nullopt;
+
+ auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
+ SmallVector<Value, 3> operands;
+ for (Value v : op->getOperands()) {
+ if (valuesMap.contains(v)) {
+ operands.push_back(valuesMap.at(v));
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ // Conceptually, we propagate the downstream TransposeOp through these
+ // intervening operations. For example,
+ //   %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
+ //   %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) -> tensor<3x2xi32>
+ // becomes:
+ //   %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) -> tensor<3x2xi32>
+ //   %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>
+ // We construct this new tosa.clamp here, but it does not turn "live" until
+ // the final downstream transpose in the chain (whose dependencies we are
+ // currently traversing) is replaced with the proper value from this new chain.
+ return rewriter
+ .create(op->getLoc(),
+ rewriter.getStringAttr(op->getName().getStringRef()), operands,
+ RankedTensorType::get(tosa::applyTOSAPermutation(
+ resultType.getShape(), downstreamPerms),
+ resultType.getElementType()),
+ op->getAttrs())
+ ->getResult(0);
+}
+
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+ tosa::TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+ SmallVector<int32_t> perms;
+ if (failed(transposeOp.getConstantPerms(perms)) ||
+ !areNullifyingTransposes(downstreamPerms, perms))
+ return std::nullopt;
+ return transposeOp.getInput1();
+}
+
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+ tosa::ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+ auto reshapeOutput = reshapeOp.getOutput();
+ auto reshapeOutputType =
+ llvm::cast<RankedTensorType>(reshapeOutput.getType());
+ if (downstreamPerms.size() !=
+ static_cast<size_t>(reshapeOutputType.getRank()))
+ return std::nullopt;
+
+ // Since perms is guaranteed to be i32,
+ // then this is OK, since --canonicalize can fold them into one.
+ auto permsAttr = rewriter.getI32TensorAttr(downstreamPerms);
+ auto permsValue = rewriter.create<tosa::ConstOp>(
+ reshapeOp.getLoc(), permsAttr.getType(), permsAttr);
+
+ // We cannot propagate the TransposeOp through the ReshapeOp as we do
+ // through ops with the TosaElementwiseOperator trait.
+ // In general, there won't be any transpose upstream of the ReshapeOp,
+ // such as in the ResNet networks.
+
+ // By inserting it here, we allow this dependency chain to be removed,
+ // and potentially allow the inserted TransposeOp to be removed later
+ // if it lends itself to the TransposeIsReshape canonicalization,
+ // as happens, for example, in the common PyTorch networks.
+
+ // There can be pathological behavior if there are many TransposeOp
+ // that do not lend themselves to the TransposeIsReshape canonicalization.
+ auto insertedOp = rewriter.create<tosa::TransposeOp>(
+ reshapeOp.getLoc(),
+ RankedTensorType::get(tosa::applyTOSAPermutation(
+ reshapeOutputType.getShape(), downstreamPerms),
+ reshapeOutputType.getElementType()),
+ reshapeOutput, permsValue->getResult(0));
+ return insertedOp->getResult(0);
+}
+
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+ tosa::ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+ auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!denseAttr)
+ return std::nullopt;
+ auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, downstreamPerms);
+ if (!maybeNewDenseAttr.has_value())
+ return std::nullopt;
+ auto newDenseAttr = maybeNewDenseAttr.value();
+ auto newConstOp = rewriter.create<tosa::ConstOp>(
+ constOp.getLoc(), newDenseAttr.getType(), newDenseAttr);
+ return newConstOp->getResult(0);
+}
+
+bool TosaRemoveRedundantTransposes::convertDependentOps(
+ SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+
+ for (Operation *op : dependentOps) {
+ if (!op || op->getNumResults() != 1)
+ return false;
+
+ Value priorValue = op->getResult(0);
+
+ // It's possible on a prior transposeOp
+ // we had the same dependency and already resolved it.
+ if (valuesMap.contains(priorValue))
+ continue;
+
+ // Keep converted ops close to the original.
+ rewriter.setInsertionPointAfter(op);
+
+ std::optional<Value> maybeValue =
+ llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
+ .Case<tosa::TransposeOp>([&](tosa::TransposeOp transposeOp) {
+ return buildMappedToValue(transposeOp, valuesMap, rewriter,
+ downstreamPerms);
+ })
+ .Case<tosa::ReshapeOp>([&](tosa::ReshapeOp reshapeOp) {
+ return buildMappedToValue(reshapeOp, valuesMap, rewriter,
+ downstreamPerms);
+ })
+ .Case<tosa::ConstOp>([&](tosa::ConstOp constOp) {
+ return buildMappedToValue(constOp, valuesMap, rewriter,
+ downstreamPerms);
+ })
+ .Default([&](Operation *op) {
+ return buildMappedToValue(op, valuesMap, rewriter,
+ downstreamPerms);
+ });
+
+ if (!maybeValue.has_value())
+ return false;
+
+ valuesMap[priorValue] = maybeValue.value();
+ }
+
+ return true;
+}
+
+// Dependencies are valid for an operation if none of them occur outside
+// of the proper fan-in cones of the downstream TransposeOp with the same perms
+// that we can replace. Described in more detail within.
+bool TosaRemoveRedundantTransposes::dependenciesAreValid(
+ ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+ std::set<tosa::TransposeOp> &validTransposes,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+ &transposeInfo) {
+ for (Operation *op : dependentOps) {
+
+ // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
+ // This can be changed later if we find the memory impact is too high.
+ if (llvm::isa<tosa::ConstOp>(op))
+ continue;
+
+ for (OpOperand &use : op->getUses()) {
+ // Want the uses to be (1) contained in the dependentOps of other
+ // validTransposes, or (2) to be directly used in a TransposeOp with the
+ // same perms. For (2), it is either (a) we inserted this for
+ // ReshapeOp conversion or (b) the fan-in is a subset of our
+ // dependentOps, so it is also a validTranspose that will eventually be
+ // replaced.
+ Operation *user = use.getOwner();
+ if (auto otherTranspose = llvm::dyn_cast<tosa::TransposeOp>(user)) {
+ SmallVector<int32_t> otherPerms;
+
+ // Can later think about cases where transpose -> transpose
+ // or reshape -> transpose, where the transposes are not necessarily
+ // the same perms as the downstream, if implementing a more general
+ // transform. These could be permitted.
+ if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
+ !llvm::equal(perms, otherPerms))
+ return false;
+
+ } else if (llvm::none_of(
+ transposeInfo,
+ [&validTransposes,
+ user](const std::pair<tosa::TransposeOp,
+ SetVector<Operation *>> &info) {
+ const auto &[transposeOp, dependentOps] = info;
+ return validTransposes.count(transposeOp) &&
+ dependentOps.contains(user);
+ })) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+// Get the set of TransposeOp that we can replace without causing
+// the old fan-in cone of any TransposeOp to remain "live", i.e., to not
+// become dead code. This is done by iterating until convergence: an op
+// used outside its own fan-in cone may still be used in the fan-in cone
+// of another TransposeOp that is being replaced, unless we find that
+// that one has a use outside of its cone too.
+std::set<tosa::TransposeOp> TosaRemoveRedundantTransposes::getGoodReplacements(
+ ArrayRef<int32_t> perms,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+ &transposeInfo) {
+ // Initially, we assume they are all good to replace,
+ // and we whittle them down based on our criteria.
+ std::set<tosa::TransposeOp> ableToReplace;
+ for (const auto &[transposeOp, _] : transposeInfo)
+ ableToReplace.insert(transposeOp);
+
+ bool gotRid;
+ do {
+ gotRid = false;
+ for (const auto &[transposeOp, dependentOps] : transposeInfo) {
+ // We don't care about it. Already invalidated.
+ if (!ableToReplace.count(transposeOp))
+ continue;
+
+ // Check for validity.
+ if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
+ transposeInfo)) {
+ ableToReplace.erase(transposeOp);
+ gotRid = true;
+ break;
+ }
+ }
+
+ } while (gotRid);
+
+ return ableToReplace;
+}
+
+void TosaRemoveRedundantTransposes::runOnOperation() {
+ // We want to operate only within a single block.
+ // Run --inline before this pass if needed.
+ // This assumption is not strict and could potentially be made more
+ // flexible.
+ if (!getOperation().getRegion().hasOneBlock())
+ return;
+
+ IRRewriter rewriter(&getContext());
+ // For each perms, maintain a mapping for converted ops, avoid duplication.
+ DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
+ // For each perms, we keep track of which tosa::TransposeOp are eligible
+ // for replacement alongside their dependentOps.
+ DenseMap<ArrayRef<int32_t>,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>>
+ permsToTransposeInfo;
+
+ // Owns the storage for perms, since DenseMap only keeps a copy of the
+ // non-owning ArrayRef. Use SmallVector for each perms (the common case is
+ // <= 4 entries), but std::vector for the collection since there is no
+ // guarantee of smallness.
+
+ // This keeps track of the order across all eligible-for-replacement
+ // TransposeOp and their perms, a necessity for the final replacements.
+ std::stack<std::pair<tosa::TransposeOp, ArrayRef<int32_t>>>
+ totalTransposeOrder;
+
+ // Reserve the space up front: SmallVector stores small arrays
+ // inline, so if collectedPerms were to reallocate, its elements
+ // would move and any ArrayRef referencing that inline storage
+ // would be invalidated.
+ size_t expectedMaxPerms = 0;
+ getOperation().walk([&](tosa::TransposeOp) { expectedMaxPerms += 1; });
+ collectedPerms.reserve(expectedMaxPerms);
+
+ getOperation().walk([&](tosa::TransposeOp transposeOp) {
+ SetVector<Operation *> dependentOps;
+ collectedPerms.emplace_back();
+ SmallVector<int32_t> &perms = collectedPerms.back();
+
+ // Dynamic shapes are OK,
+ // but the incompatible ones will be rejected later.
+ auto input = transposeOp.getInput1();
+ auto output = transposeOp.getOutput();
+
+ // However, we don't support unranked tensors.
+ if (!llvm::isa<RankedTensorType>(input.getType()) ||
+ !llvm::isa<RankedTensorType>(output.getType()))
+ return;
+
+ // No transformation when the transpose permutation is non-constant.
+ if (failed(transposeOp.getConstantPerms(perms)))
+ return;
+
+ // We let --canonicalize deal with identity transpose.
+ if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
+ return;
+
+ // Can fail if some basic invariants that our conversions rely on
+ // are not met.
+ if (!collectFanIn(input.getDefiningOp(), dependentOps))
+ return;
+
+ // Associate a valuesMap of already-converted values with each perms,
+ // since multiple downstream transposes with different perms may converge
+ // on an op, which would result in different transformations.
+ DenseMap<Value, Value> &valuesMap = permsToValues[perms];
+
+ // Attempt to perform the conversions and placements into IR
+ // without turning inserted code "live". Also fills out valuesMap.
+ // Fails if there is an intermediary we do not support.
+ if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
+ // Some additional operations may have been inserted, but will be
+ // removed by dead code elimination and --canonicalize.
+ return;
+
+ // This should not happen. If it does -- it's unexpected,
+ // so we fail the pass.
+ if (!valuesMap.contains(input))
+ return signalPassFailure();
+
+ // It's possible the types are not compatible (because of dynamic shapes).
+ // In these cases, dynamic shapes should be resolved before running the
+ // pass.
+ if (output.getType() != valuesMap.at(input).getType())
+ return;
+
+ auto &transposeInfo = permsToTransposeInfo[perms];
+
+ // In general, we might also want to introduce "newDependentOps"
+ // if there are new usages that don't fall inside the original fan-ins
+ // (like the tosa::TransposeOp we insert for tosa::ReshapeOp),
+ // but in this case, that is specialized enough and overlaps
+ // with another direct-use tosa::TransposeOp case we need to cover anyway.
+ transposeInfo.push_back({transposeOp, dependentOps});
+
+ // This is for the final replacement across all transposes.
+ totalTransposeOrder.push({transposeOp, perms});
+ });
+
+ // We want to do a full fan-in analysis on a perms-level,
+ // since if we do it on a multi-perms level, and they share (due to a shared
+ // dependency on a Reshape) then we would also get duplicate ops.
+ // Const is special cased.
+ std::set<tosa::TransposeOp> ableToReplace;
+ for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
+ // Gives us back replacements that would never result in any duplicate
+ // operations being inserted by us in the IR (i.e., our goal is only to
+ // remove transposes, and not create a "new chain" to do so, but replace
+ // the existing chains).
+ // Ideally, --canonicalize is run before this pass, since it helps this
+ // analysis by removing dead code to allow more potentially acceptable
+ // transformations.
+ auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
+ ableToReplace.insert(goodReplacementsForPerms.begin(),
+ goodReplacementsForPerms.end());
+ }
+
+ // We want to do replacement across all transposes
+ // in reverse order, due to invalidation of valuesMap mappings
+ // if we did it otherwise.
+ while (!totalTransposeOrder.empty()) {
+ auto [transposeOp, perms] = totalTransposeOrder.top();
+ totalTransposeOrder.pop();
+
+ if (ableToReplace.count(transposeOp) == 0)
+ continue;
+
+ auto &valuesMap = permsToValues[perms];
+ auto input = transposeOp.getInput1();
+
+ // The purpose of this reverse iteration
+ // is to avoid valuesMap invalidation. If it happens,
+ // something is wrong.
+ if (!valuesMap.contains(input))
+ return signalPassFailure();
+
+ rewriter.replaceOp(transposeOp, valuesMap.at(input));
+ }
+}
+
+} // namespace
+
+std::unique_ptr<Pass> tosa::createTosaRemoveRedundantTransposes() {
+ return std::make_unique<TosaRemoveRedundantTransposes>();
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-dynamic-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-dynamic-shapes.mlir
new file mode 100644
index 00000000000000..89da09a697114f
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-dynamic-shapes.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt --verify-diagnostics --split-input-file --verify-each --tosa-remove-redundant-transposes %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @test_unknown_dim_inner_replacement_matches
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose %arg0
+// CHECK-NEXT: return %arg0
+func.func @test_unknown_dim_inner_replacement_matches(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<?x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<?x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ return %1 : tensor<3x2xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @test_unknown_dim_outer_replacement_matches
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose %arg0
+// CHECK-NEXT: return %arg0
+func.func @test_unknown_dim_outer_replacement_matches(%arg0: tensor<3x?xi32>) -> tensor<3x?xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x?xi32>
+ return %1 : tensor<3x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying_diverging_binary_replacements_match
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose %arg0
+// CHECK-NEXT: tosa.transpose %arg1
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: tosa.abs
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %arg1
+// CHECK-NEXT: tosa.add
+// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
+// CHECK-NEXT: tosa.const
+// CHECK-NOT: tosa.transpose
+// CHECK-NEXT: return %[[ADD]]
+func.func @test_transpose_tracks_to_nullifying_diverging_binary_replacements_match(%arg0: tensor<1x?x3x4xi32>, %arg1: tensor<1x2x?x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x3x4x?xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x?x4xi32>, tensor<4xi32>) -> tensor<1x?x?x2xi32>
+ %clamp = tosa.clamp %transpose0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<?x3x4x?xi32>) -> tensor<?x3x4x?xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x?x?x2xi32>) -> tensor<1x?x?x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<?x3x4x?xi32>, tensor<1x?x?x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-pipeline.mlir b/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-pipeline.mlir
new file mode 100644
index 00000000000000..fb508adab40639
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-pipeline.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt --verify-diagnostics --split-input-file --verify-each --tosa-remove-redundant-transposes --canonicalize %s | FileCheck %s
+// COM: the purpose of this file is to demonstrate real examples of transformations on TOSA lowerings with the full
+// COM: power of --tosa-remove-redundant-transposes in conjunction with --canonicalize
+
+// COM: same as @test_resnet18_common_case in tosa-remove-redundant-transposes-static-shapes.mlir
+
+// CHECK-LABEL: @test_resnet18
+// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[VAL_6:.*]] = tosa.add %arg1, %[[VAL_4]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = tosa.pow %[[VAL_6]], %[[VAL_5]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor<64xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_9:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_10:.*]] = tosa.sub %arg2, %[[VAL_9]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_12:.*]] = tosa.mul %[[VAL_10]], %[[VAL_11]] {shift = 0 : i8} : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_15:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 64>} : (tensor<64xf32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_16:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_17:.*]] = tosa.clamp %[[VAL_16]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK: return %[[VAL_17]] : tensor<1x112x112x64xf32>
+
+func.func @test_resnet18(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> {
+ %59 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
+ %60 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
+ %63 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %64 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %69 = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %70 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %75 = tosa.transpose %74, %64 : (tensor<1x112x112x64xf32>, tensor<4xi32>) -> tensor<1x64x112x112xf32>
+ %76 = tosa.add %arg1, %69 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+ %77 = tosa.pow %76, %70 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+ %78 = tosa.reciprocal %77 : (tensor<64xf32>) -> tensor<64xf32>
+ %79 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %81 = tosa.reshape %78 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %82 = tosa.mul %80, %81 {shift = 0 : i8} : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %83 = tosa.reshape %60 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %84 = tosa.mul %82, %83 {shift = 0 : i8} : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %85 = tosa.reshape %59 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %87 = tosa.clamp %86 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>
+ %88 = tosa.transpose %87, %63 : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
+ return %88 : tensor<1x112x112x64xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-static-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-static-shapes.mlir
new file mode 100644
index 00000000000000..acc0baf0810358
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-static-shapes.mlir
@@ -0,0 +1,552 @@
+// RUN: mlir-opt --verify-diagnostics --split-input-file --verify-each --tosa-remove-redundant-transposes %s | FileCheck %s
+// COM: note that this pass makes no attempt to remove dead code. that is done by --canonicalize
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying__single_step
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose %arg0
+// CHECK-NEXT: tosa.ceil
+// CHECK-NEXT: %[[RESULT:.*]] = tosa.ceil %arg0
+// CHECK-NEXT: tosa.const
+// CHECK-NOT: tosa.transpose
+// CHECK-NEXT: return %[[RESULT]]
+func.func @test_transpose_tracks_to_nullifying__single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %ceil = tosa.ceil %0 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %ceil, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying__multi_unary_step
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose %arg0
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: tosa.abs
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %[[CLAMP]]
+// CHECK-NEXT: tosa.bitwise_not
+// CHECK-NEXT: %[[NOT:.*]] = tosa.bitwise_not %[[ABS]]
+// CHECK-NEXT: tosa.const
+// CHECK-NOT: tosa.transpose
+// CHECK-NEXT: return %[[NOT]]
+func.func @test_transpose_tracks_to_nullifying__multi_unary_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %bitwise_not, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying__diverging_binary
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose %arg0
+// CHECK-NEXT: tosa.transpose %arg1
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: tosa.abs
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %arg1
+// CHECK-NEXT: tosa.add
+// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
+// CHECK-NEXT: tosa.const
+// CHECK-NOT: tosa.transpose
+// CHECK-NEXT: return %[[ADD]]
+func.func @test_transpose_tracks_to_nullifying__diverging_binary(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying__diverging_binary_with_broadcasting
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose %arg0
+// CHECK-NEXT: tosa.transpose %arg1
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: %[[CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: tosa.abs
+// CHECK-NEXT: %[[ABS:.*]] = tosa.abs %arg1
+// CHECK-NEXT: tosa.add
+// CHECK-NEXT: %[[ADD:.*]] = tosa.add %[[CLAMP]], %[[ABS]]
+// CHECK-NEXT: tosa.const
+// CHECK-NOT: tosa.transpose
+// CHECK-NEXT: return %[[ADD]]
+func.func @test_transpose_tracks_to_nullifying__diverging_binary_with_broadcasting(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x1x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x1x4xi32>, tensor<4xi32>) -> tensor<1x1x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x1x4x2xi32>) -> tensor<1x1x4x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x1x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_nullifying__converging_binary
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose %arg0
+// CHECK-NEXT: tosa.add
+// CHECK-NEXT: %[[RESULT:.*]] = tosa.add %arg0, %arg0
+// CHECK-NEXT: tosa.const
+// CHECK-NOT: tosa.transpose
+// CHECK-NEXT: return %[[RESULT]]
+func.func @test_transpose_tracks_to_nullifying__converging_binary(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %0, %0 : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %add, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_torch_conv2d_with_elementwise_in_between
+// CHECK: %[[CONV1:.*]] = tosa.conv2d
+// CHECK: %[[CEIL:.*]] = tosa.ceil %[[CONV1]]
+// CHECK: %[[CONV2:.*]] = tosa.conv2d %[[CEIL]]
+// CHECK: %[[FLOOR:.*]] = tosa.floor %[[CONV2]]
+// CHECK: %[[CONV3:.*]] = tosa.conv2d %[[FLOOR]]
+// CHECK: %[[RES:.*]] = tosa.transpose %[[CONV3]]
+// CHECK: return %[[RES]]
+func.func @test_torch_conv2d_with_elementwise_in_between(%arg0: tensor<3x3x10x10xf32>) -> tensor<3x3x7x7xf32> {
+ %0 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32_2> : tensor<3xf32>}> : () -> tensor<3xf32>
+ %1 = "tosa.const"() <{value = dense_resource<torch_tensor_3_3_2_2_torch.float32_2> : tensor<3x3x2x2xf32>}> : () -> tensor<3x3x2x2xf32>
+ %2 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32_1> : tensor<3xf32>}> : () -> tensor<3xf32>
+ %3 = "tosa.const"() <{value = dense_resource<torch_tensor_3_3_2_2_torch.float32_1> : tensor<3x3x2x2xf32>}> : () -> tensor<3x3x2x2xf32>
+ %4 = "tosa.const"() <{value = dense_resource<torch_tensor_3_3_2_2_torch.float32> : tensor<3x3x2x2xf32>}> : () -> tensor<3x3x2x2xf32>
+ %5 = "tosa.const"() <{value = dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>}> : () -> tensor<3xf32>
+ %6 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %7 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %8 = tosa.transpose %arg0, %6 : (tensor<3x3x10x10xf32>, tensor<4xi32>) -> tensor<3x10x10x3xf32>
+ %9 = tosa.transpose %4, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %10 = tosa.conv2d %8, %9, %5 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x10x10x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x9x9x3xf32>
+ %11 = tosa.transpose %10, %7 : (tensor<3x9x9x3xf32>, tensor<4xi32>) -> tensor<3x3x9x9xf32>
+ %12 = tosa.ceil %11 : (tensor<3x3x9x9xf32>) -> tensor<3x3x9x9xf32>
+ %13 = tosa.transpose %12, %6 : (tensor<3x3x9x9xf32>, tensor<4xi32>) -> tensor<3x9x9x3xf32>
+ %14 = tosa.transpose %3, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %15 = tosa.conv2d %13, %14, %2 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x9x9x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x8x8x3xf32>
+ %16 = tosa.transpose %15, %7 : (tensor<3x8x8x3xf32>, tensor<4xi32>) -> tensor<3x3x8x8xf32>
+ %17 = tosa.floor %16 : (tensor<3x3x8x8xf32>) -> tensor<3x3x8x8xf32>
+ %18 = tosa.transpose %17, %6 : (tensor<3x3x8x8xf32>, tensor<4xi32>) -> tensor<3x8x8x3xf32>
+ %19 = tosa.transpose %1, %6 : (tensor<3x3x2x2xf32>, tensor<4xi32>) -> tensor<3x2x2x3xf32>
+ %20 = tosa.conv2d %18, %19, %0 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x8x8x3xf32>, tensor<3x2x2x3xf32>, tensor<3xf32>) -> tensor<3x7x7x3xf32>
+ %21 = tosa.transpose %20, %7 : (tensor<3x7x7x3xf32>, tensor<4xi32>) -> tensor<3x3x7x7xf32>
+ return %21 : tensor<3x3x7x7xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_mulop_conversion
+// CHECK-NEXT: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.mul
+// CHECK-NEXT: %[[RES:.*]] = tosa.mul %arg0, %arg1
+// CHECK-NEXT: tosa.const
+// CHECK-NOT: tosa.transpose
+// CHECK-NEXT: return %[[RES]]
+func.func @test_mulop_conversion(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %mul = tosa.mul %transpose0, %transpose1 {shift = 0 : i8} : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %mul, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// COM: this case is a simple example of how ResNet18 functions, where they
+// COM: reshape values in NCHW and rely on the downstream transpose to turn it to NHWC.
+// COM: although it looks as though no transform occurs, the transpose is added and subsequently replaced
+// CHECK-LABEL: @test_basic_reshape
+// CHECK: tosa.const
+// CHECK: tosa.reshape
+// CHECK: tosa.transpose
+func.func @test_basic_reshape(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> {
+ %perms = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, 3, 2>} : (tensor<2x3xi32>) -> tensor<1x3x2xi32>
+ %2 = tosa.transpose %1, %perms : (tensor<1x3x2xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+ return %2 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// COM: omitting checks for dead code
+// COM: in fact, despite a new transpose being inserted in this case and the downstream removed,
+// COM: --canonicalize is able to fold it into the reshape. additionally, the other
+// COM: upstream is removed as well (since no other uses)
+// CHECK-LABEL: @test_reshape_for_broadcast
+// CHECK-DAG: %[[PERMS:.*]] = "tosa.const"() <{value = dense<[2, 1, 0]> : tensor<3xi32>}>
+// CHECK-DAG: %[[RESHAPE_INPUT:.*]] = "tosa.const"() <{value = dense<[1, 2, 3, 4]>
+// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %[[RESHAPE_INPUT]] {new_shape = array<i64: 1, 1, 4>}
+// CHECK-DAG: %[[TRANSPOSE_INSERTED:.*]] = tosa.transpose %[[RESHAPE]], %[[PERMS]] : (tensor<1x1x4xi32>, tensor<3xi32>)
+// CHECK-DAG: %[[ADD:.*]] = tosa.add %arg0, %[[TRANSPOSE_INSERTED]]
+// CHECK: return %[[ADD]]
+func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2xi32> {
+ %0 = "tosa.const"() {value = dense<[1,2,3,4]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %reshape = tosa.reshape %0 {new_shape = array<i64: 1, 1, 4>} : (tensor<4xi32>) -> tensor<1x1x4xi32>
+ %perms0 = "tosa.const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<4x3x2xi32>, tensor<3xi32>) -> tensor<2x3x4xi32>
+ %add = tosa.add %transpose0, %reshape : (tensor<2x3x4xi32>, tensor<1x1x4xi32>) -> tensor<2x3x4xi32>
+ %transpose1 = tosa.transpose %add, %perms0 : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x3x2xi32>
+ return %transpose1 : tensor<4x3x2xi32>
+}
+
+// -----
+
+// COM: taken directly from ResNet18 translation.
+// COM: changes: %74 as argument instead of result of conv2d
+
+// COM: despite the fact that we have more transposes, since these 'transposes are reshapes', if we run --canonicalize right after
+// COM: the pass, we see all the transposes disappear and exactly zero are left! so the pass enables this to happen.
+// COM: we demonstrate here to show the results of just our pass, and in the pipeline file, to demonstrate the full reduction.
+
+// CHECK-LABEL: @test_resnet18_common_case
+// COM: note that %74 is now represented by %arg2
+// CHECK-DAG: "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %[[VAL_01:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %[[VAL_02:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %[[VAL_03:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %[[VAL_04:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[VAL_9:.*]] = tosa.add %arg1, %[[VAL_6]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_10:.*]] = tosa.pow %[[VAL_9]], %[[VAL_7]] : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_11:.*]] = tosa.reciprocal %[[VAL_10]] : (tensor<64xf32>) -> tensor<64xf32>
+// CHECK-DAG: %[[VAL_12:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+// CHECK-DAG: %[[VAL_13:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_01]] : (tensor<1x64x1x1xf32>, tensor<4xi32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_15:.*]] = tosa.sub %arg2, %[[VAL_13]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_16:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+// CHECK-DAG: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]], %[[VAL_02]] : (tensor<1x64x1x1xf32>, tensor<4xi32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_19:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]] {shift = 0 : i8} : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_20:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+// CHECK-DAG: %[[VAL_21:.*]] = tosa.transpose %[[VAL_20]], %[[VAL_03]] : (tensor<1x64x1x1xf32>, tensor<4xi32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_23:.*]] = tosa.mul %[[VAL_19]], %[[VAL_21]] {shift = 0 : i8} : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_24:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+// CHECK-DAG: %[[VAL_25:.*]] = tosa.transpose %[[VAL_24]], %[[VAL_04]] : (tensor<1x64x1x1xf32>, tensor<4xi32>) -> tensor<1x1x1x64xf32>
+// CHECK-DAG: %[[VAL_27:.*]] = tosa.add %[[VAL_23]], %[[VAL_25]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK-DAG: %[[VAL_29:.*]] = tosa.clamp %[[VAL_27]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>
+// CHECK: return %[[VAL_29]] : tensor<1x112x112x64xf32>
+
+func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> {
+ %59 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32_1> : tensor<64xf32>}> : () -> tensor<64xf32>
+ %60 = "tosa.const"() <{value = dense_resource<torch_tensor_64_torch.float32> : tensor<64xf32>}> : () -> tensor<64xf32>
+ %63 = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %64 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %69 = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %70 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %75 = tosa.transpose %74, %64 : (tensor<1x112x112x64xf32>, tensor<4xi32>) -> tensor<1x64x112x112xf32>
+ %76 = tosa.add %arg1, %69 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+ %77 = tosa.pow %76, %70 : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32>
+ %78 = tosa.reciprocal %77 : (tensor<64xf32>) -> tensor<64xf32>
+ %79 = tosa.reshape %arg0 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %81 = tosa.reshape %78 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %82 = tosa.mul %80, %81 {shift = 0 : i8} : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %83 = tosa.reshape %60 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %84 = tosa.mul %82, %83 {shift = 0 : i8} : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %85 = tosa.reshape %59 {new_shape = array<i64: 1, 64, 1, 1>} : (tensor<64xf32>) -> tensor<1x64x1x1xf32>
+ %86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32>
+ %87 = tosa.clamp %86 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>
+ %88 = tosa.transpose %87, %63 : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
+ return %88 : tensor<1x112x112x64xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @test_back_to_back_nullifiers
+// CHECK: %[[PERMS:.*]] = "tosa.const"
+// CHECK: %[[RES:.*]] = tosa.transpose %arg0, %[[PERMS]]
+// CHECK: return %[[RES]]
+func.func @test_back_to_back_nullifiers(%arg0: tensor<2x3xi32>) -> tensor<3x2xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %2 = tosa.transpose %1, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ return %2 : tensor<3x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_back_to_back_nullifiers_different_transposes
+// CHECK: %[[PERMS:.*]] = "tosa.const"
+// CHECK: "tosa.const"
+// CHECK: %[[RES:.*]] = tosa.transpose %arg0, %[[PERMS]]
+// CHECK: return %[[RES]]
+func.func @test_back_to_back_nullifiers_different_transposes(%arg0: tensor<2x3x4x5xi32>) -> tensor<2x4x5x3xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<2x4x5x3xi32>
+ %1 = tosa.transpose %0, %perms1 : (tensor<2x4x5x3xi32>, tensor<4xi32>) -> tensor<2x3x4x5xi32>
+ %2 = tosa.transpose %1, %perms0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<2x4x5x3xi32>
+ return %2 : tensor<2x4x5x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_transform_if_outside_fan_in_cone
+// CHECK: tosa.const
+// CHECK: %[[CLAMP_IN:.*]] = tosa.transpose
+// CHECK: %[[RES2:.*]] = tosa.clamp %[[CLAMP_IN]]
+// CHECK: tosa.const
+// CHECK: %[[RES1:.*]] = tosa.transpose
+// CHECK: return %[[RES1]], %[[RES2]]
+func.func @test_no_transform_if_outside_fan_in_cone(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %clamp : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_two_different_downstream_converge_to_reshape_same_perms
+// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %arg0
+// CHECK-DAG: %[[NEW_TRANSPOSE:.*]] = tosa.transpose %[[RESHAPE]]
+// CHECK-DAG: %[[CLAMP:.*]] = tosa.clamp %[[NEW_TRANSPOSE]]
+// CHECK: return %[[NEW_TRANSPOSE]], %[[CLAMP]]
+func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: tensor<64xf32>) -> (tensor<4x4x4xf32>, tensor<4x4x4xf32>) {
+ %0 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+ %1 = tosa.reshape %arg0 {new_shape = array<i64: 4, 4, 4>} : (tensor<64xf32>) -> tensor<4x4x4xf32>
+ %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
+ %3 = tosa.transpose %1, %0 : (tensor<4x4x4xf32>, tensor<3xi32>) -> tensor<4x4x4xf32>
+ %4 = tosa.transpose %2, %0 : (tensor<4x4x4xf32>, tensor<3xi32>) -> tensor<4x4x4xf32>
+ return %3, %4 : tensor<4x4x4xf32>, tensor<4x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_two_different_downstream_converge_to_reshape_different_perms
+// CHECK: tosa.const
+// CHECK: tosa.const
+// CHECK: %[[RESHAPE:.*]] = tosa.reshape
+// CHECK: %[[CLAMP:.*]] = tosa.clamp %[[RESHAPE]]
+// CHECK: %[[RET1:.*]] = tosa.transpose
+// CHECK: %[[RET2:.*]] = tosa.transpose
+// CHECK: return %[[RET1]], %[[RET2]]
+func.func @test_two_different_downstream_converge_to_reshape_different_perms(%arg0: tensor<64xf32>) -> (tensor<4x4x4xf32>, tensor<4x4x4xf32>) {
+ %0 = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
+ %1 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+ %2 = tosa.reshape %arg0 {new_shape = array<i64: 4, 4, 4>} : (tensor<64xf32>) -> tensor<4x4x4xf32>
+ %3 = tosa.clamp %2 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
+ %4 = tosa.transpose %2, %1 : (tensor<4x4x4xf32>, tensor<3xi32>) -> tensor<4x4x4xf32>
+ %5 = tosa.transpose %3, %0 : (tensor<4x4x4xf32>, tensor<3xi32>) -> tensor<4x4x4xf32>
+ return %4, %5 : tensor<4x4x4xf32>, tensor<4x4x4xf32>
+}
+
+// -----
+
+// COM: no transform
+// CHECK-LABEL: @test_outside_perms_usage_of_fan_in
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.clamp
+// CHECK: %[[RES1:.*]] = tosa.transpose
+// CHECK: tosa.add
+// CHECK: return %[[RES1]]
+func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> (tensor<2x3xf32>, tensor<3x2xf32>) { %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x2xf32>) -> tensor<3x2xf32>
+ %3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ %4 = tosa.add %arg1, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
+ return %3, %4: tensor<2x3xf32>, tensor<3x2xf32>
+}
+
+// -----
+
+// COM: this use-case is important for ResNet. we want to allow such uses, but disallow them if they fall into an illegal area (outside the fan-ins that get converted),
+// COM: since then we would get duplicate ops
+// CHECK-LABEL: @test_use_present_in_another_valid_perms_fan_in
+// CHECK-DAG: %[[NEW_CLAMP:.*]] = tosa.clamp %arg0
+// CHECK-DAG: %[[NEW_ADD:.*]] = tosa.add %arg1, %[[NEW_CLAMP]]
+// CHECK: return %[[NEW_CLAMP]], %[[NEW_ADD]]
+func.func @test_use_present_in_another_valid_perms_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
+ %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x2xf32>) -> tensor<3x2xf32>
+ %3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ %4 = tosa.transpose %arg1, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ %5 = tosa.add %4, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
+ %6 = tosa.transpose %5, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32>
+ return %3, %6: tensor<2x3xf32>, tensor<2x3xf32>
+}
+
+// -----
+
+// COM: no transform, since we would get duplicates
+// CHECK-LABEL: @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: %[[CEIL:.*]] = tosa.ceil
+// CHECK: %[[ADD:.*]] = tosa.add %[[CEIL]]
+// CHECK: %[[RES1:.*]] = tosa.transpose %[[CEIL]]
+// CHECK: %[[RES2:.*]] = tosa.transpose %[[ADD]]
+// CHECK: return %[[RES1]], %[[RES2]]
+func.func @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents(%arg0: tensor<2x3xi32>, %arg1: tensor<3x2xi32>) -> (tensor<2x3xi32>, tensor<2x3xi32>) {
+ %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ %2 = tosa.ceil %1 : (tensor<3x2xi32>) -> tensor<3x2xi32>
+ %3 = tosa.add %2, %arg1 : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
+ %4 = tosa.transpose %2, %0 : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %5 = tosa.transpose %3, %0 : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ return %4, %5 : tensor<2x3xi32>, tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_direct_use_in_other_transpose_with_same_perms
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: %[[RES:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: tosa.const
+// CHECK-NOT: tosa.transpose
+// CHECK-NEXT: return %[[RES]], %[[RES]]
+func.func @test_direct_use_in_other_transpose_with_same_perms(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_const_transpose
+// CHECK: %[[NEW:.*]] = "tosa.const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+// CHECK-NOT: tosa.transpose
+// CHECK: return %[[NEW]]
+func.func @test_const_transpose() -> tensor<2x3xi32> {
+ %0 = "tosa.const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ return %1 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_transpose_tracks_to_const__single_step
+// CHECK: %[[NEW_CONST:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %[[NEW_CONST]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK-NOT: tosa.transpose
+// CHECK: return %[[NEW_CLAMP]]
+func.func @test_transpose_tracks_to_const__single_step() -> tensor<1x2x3x4xi32> {
+ %0 = "tosa.const"() {value = dense<0> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_unary_path_to_const
+// CHECK: %[[NEW_CONST:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %[[NEW_CONST]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_ABS:.*]] = tosa.abs %[[NEW_CLAMP]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_NOT:.*]] = tosa.bitwise_not %[[NEW_ABS]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: return %[[NEW_NOT]]
+func.func @test_static_unary_path_to_const() -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = "tosa.const"() {value = dense<1> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %bitwise_not, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %1 : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_static_diverges_to_non_splat_const_and_nullifying
+// CHECK: tosa.const
+// CHECK: tosa.const
+// CHECK: %[[NEW_CONST:.*]] = "tosa.const"()
+// CHECK-SAME{LITERAL}: dense<[[[[1, 3, 5, 7], [9, 11, 13, 15], [17, 19, 21, 23]], [[2, 4, 6, 8], [10, 12, 14, 16], [18, 20, 22, 24]]]]>
+// CHECK: tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %arg0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_ABS:.*]] = tosa.abs %[[NEW_CONST]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: %[[NEW_ADD:.*]] = tosa.add %[[NEW_ABS]], %[[NEW_CLAMP]] : (tensor<1x2x3x4xi32>, tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32>
+// CHECK: return %[[NEW_ADD]]
+func.func @test_static_diverges_to_non_splat_const_and_nullifying(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %const = "tosa.const"() {value = dense<[[[[1, 2], [3, 4], [5, 6], [7, 8]],
+ [[9, 10], [11, 12], [13, 14], [15, 16]],
+ [[17, 18], [19, 20], [21, 22], [23, 24]]]]> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %const : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %abs, %clamp : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms2 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_multi_downstream_both_nullify
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: %[[RES:.*]] = tosa.clamp %arg0
+// CHECK-NEXT: tosa.const
+// CHECK-NOT: tosa.transpose
+// CHECK-NEXT: return %[[RES]], %[[RES]]
+func.func @test_multi_downstream_both_nullify(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
+
+// -----
+
+// COM: we don't perform this transformation intentionally, since we would then get duplicates
+// CHECK-LABEL: @test_multi_downstream_one_nullifies_upstream_other_does_not
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.clamp
+// CHECK: tosa.const
+// CHECK: tosa.transpose
+// CHECK: tosa.transpose
+func.func @test_multi_downstream_one_nullifies_upstream_other_does_not(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ %2 = tosa.transpose %clamp, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32>
+ return %1, %2 : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-unimplemented.mlir b/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-unimplemented.mlir
new file mode 100644
index 00000000000000..bef5706c07ef83
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-unimplemented.mlir
@@ -0,0 +1,120 @@
+// RUN: mlir-opt --verify-diagnostics --split-input-file --verify-each --tosa-remove-redundant-transposes %s | FileCheck %s
+
+// COM: we cannot do anything to the transpose in this case.
+// CHECK-LABEL: @test_non_const_perms
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_non_const_perms(%perms: tensor<2xi32>) -> tensor<?x?xi32> {
+ %0 = "tosa.const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<?x?xi32>
+ return %1 : tensor<?x?xi32>
+}
+
+// -----
+
+// COM: due to tracking back to a non-nullifying transpose, we can't get rid of the transposes entirely.
+// COM: later versions of the pass may wish to fold these into a single transpose.
+// CHECK-LABEL: @test_transpose_tracks_to_non_nullifying_transpose__single_step
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_transpose_tracks_to_non_nullifying_transpose__single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x4x3xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x4x3x2xi32>
+ %clamp = tosa.clamp %0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<1x4x3x2xi32>) -> tensor<1x4x3x2xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %1 = tosa.transpose %clamp, %perms1 : (tensor<1x4x3x2xi32>, tensor<4xi32>) -> tensor<1x2x4x3xi32>
+ return %1 : tensor<1x2x4x3xi32>
+}
+
+// -----
+
+// COM: we don't deal with this case. --tosa-infer-shapes is required.
+// CHECK-LABEL: @test_unknown_dim_input_nullifying_pair
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unknown_dim_input_nullifying_pair(%arg0: tensor<3x?xi32>) -> tensor<3x2xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<2x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
+ return %1 : tensor<3x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unknown_dim__replacement_does_not_match
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unknown_dim__replacement_does_not_match(%arg0: tensor<3x?xi32>) -> tensor<?x?xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x?xi32>, tensor<2xi32>) -> tensor<?x3xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<?x3xi32>, tensor<2xi32>) -> tensor<?x?xi32>
+ return %1 : tensor<?x?xi32>
+}
+
+// -----
+
+// COM: this could be converted if --tosa-infer-shapes was run beforehand
+// CHECK-LABEL: @test_unranked_tensors_present
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unranked_tensors_present(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+ %perms = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<*xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unranked_everything
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_unranked_everything(%arg0: tensor<*xi32>) -> tensor<*xi32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ %0 = tosa.transpose %arg0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ %1 = tosa.transpose %0, %perms : (tensor<*xi32>, tensor<2xi32>) -> tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// -----
+
+// COM: this is an example of some dead code we generate despite no transform taking place.
+// COM: it will be removed by --canonicalize.
+// COM: it's generated because at the add, we track back on the first argument first.
+// CHECK-LABEL: @test_static_diverges_to_one_nullifying_one_non_nullifying
+// CHECK: tosa.const
+// CHECK-NEXT: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: tosa.clamp
+// CHECK-NEXT: tosa.abs
+// CHECK-NEXT: tosa.add
+// CHECK-NEXT: tosa.const
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: return
+func.func @test_static_diverges_to_one_nullifying_one_non_nullifying(%arg0: tensor<1x2x3x4xi32>, %arg1: tensor<1x2x4x3xi32>) -> tensor<1x2x3x4xi32> {
+ %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %perms1 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %transpose1 = tosa.transpose %arg1, %perms1 : (tensor<1x2x4x3xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32>
+ %clamp = tosa.clamp %transpose0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32>
+ %perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %result = tosa.transpose %add, %perms2 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32>
+ return %result : tensor<1x2x3x4xi32>
+}
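
The tests above distinguish "nullifying" transpose pairs (whose composition is the identity) from non-nullifying ones. As a rough illustration only, and not the pass's actual implementation, the check can be sketched in Python. The helper names `transpose_shape`, `compose_perms` and `is_nullifying` are hypothetical, and the sketch assumes the TOSA convention that output dimension i takes input dimension perms[i].

```python
from typing import List


def transpose_shape(shape: List[int], perms: List[int]) -> List[int]:
    """Shape after a transpose, assuming output dim i = input dim perms[i]."""
    return [shape[p] for p in perms]


def compose_perms(first: List[int], second: List[int]) -> List[int]:
    """Single perms equivalent to applying `first`, then `second`.

    After both transposes, output dim i comes from original dim
    first[second[i]].
    """
    return [first[s] for s in second]


def is_nullifying(first: List[int], second: List[int]) -> bool:
    """True when applying `first` then `second` is the identity permutation."""
    if len(first) != len(second):
        return False
    return compose_perms(first, second) == list(range(len(first)))


if __name__ == "__main__":
    # NCHW <-> NHWC style pair from the tests: the two transposes cancel.
    print(is_nullifying([0, 2, 3, 1], [0, 3, 1, 2]))  # True
    # Pair from the "unimplemented" test file: the two transposes do not cancel.
    print(is_nullifying([0, 3, 2, 1], [0, 3, 1, 2]))  # False
```

Under these assumptions, the pair [0, 2, 3, 1] and [0, 3, 1, 2] composes to the identity, while [0, 3, 2, 1] followed by [0, 3, 1, 2] composes to [0, 1, 3, 2], which matches the 1x2x3x4 to 1x2x4x3 result type in the non-nullifying test.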