[Mlir-commits] [mlir] [MLIR][TOSA] Add --tosa-remove-redundant-transposes pass (PR #108260)
llvmlistbot at llvm.org
Wed Sep 11 10:45:59 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-tosa
Author: Arteen Abrishami (arteen1000)
<details>
<summary>Changes</summary>
----------
Motivation:
----------
Some legalization pathways introduce redundant tosa.TRANSPOSE operations that result in avoidable data movement. For example, the PyTorch -> TOSA pathway produces many unnecessary transposes due to conversions between NCHW and NHWC.
We wish to remove all the ones that we can, since in general it is possible to remove upwards of 90% of these transposes in a provable manner.
------------
Changes Made:
------------
- Add the --tosa-remove-redundant-transposes pass
- Add TosaElementwiseOperator trait.
-------------------
High-Level Overview:
-------------------
The pass begins at a downstream transpose with some perms tensor. It traverses the dependencies upward, accepting only operators with the TosaElementwiseOperator trait. Dependencies must terminate in nullifying transposes (transposes that, composed with the downstream one, form the identity), reshapes, or consts.
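As an illustration of the traversal just described, here is a minimal C++ sketch on a hypothetical toy graph. The `Kind`, `Node`, and `collectFanIn` names below are invented for this example and are not MLIR APIs. It assumes nodes reference their producers by index and, like the pass, records ops in topological order (producers first) and gives up on anything that is neither a source nor elementwise.

```cpp
#include <cassert>
#include <unordered_set>
#include <vector>

// Hypothetical toy IR: each node has a kind and the indices of its producers.
enum class Kind { Const, Reshape, Transpose, Elementwise, Other, Argument };

struct Node {
  Kind kind;
  std::vector<int> operands; // indices of producer nodes
};

// Collects the fan-in of `id` in topological order (producers first).
// Traversal stops at Const, Reshape, and Transpose sources, passes through
// Elementwise nodes, and fails on anything else.
bool collectFanIn(const std::vector<Node> &graph, int id,
                  std::unordered_set<int> &seen, std::vector<int> &order) {
  assert(id >= 0 && id < static_cast<int>(graph.size()));
  if (seen.count(id))
    return true; // already collected through another path
  const Node &n = graph[id];
  switch (n.kind) {
  case Kind::Const:
  case Kind::Reshape:
  case Kind::Transpose:
    break; // sources: no further traversal upward
  case Kind::Elementwise:
    for (int operand : n.operands)
      if (!collectFanIn(graph, operand, seen, order))
        return false;
    break;
  default:
    return false; // e.g. a function argument or an unsupported op
  }
  seen.insert(id);
  order.push_back(id); // inserted after its producers
  return true;
}
```

In the pass itself this role is played by `collectFanIn` over a `SetVector<Operation *>`; the toy version only uses integer ids to keep the example self-contained.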
Conceptually, we then "bubble up" the downstream transpose until we hit the sources. For constants, we generate new constants composed with the downstream transpose. For nullifying transposes, we "cancel" them. For reshapes, we generally cannot "bubble" through them, so we insert the downstream transpose after the reshape.
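The cancellation test can be phrased as permutation composition. The sketch below assumes the TOSA convention that output dimension `i` of a transpose is input dimension `perms[i]`; under that assumption, applying `perms1` and then `perms2` equals a single transpose with `composed[i] = perms1[perms2[i]]`, and the pair is nullifying exactly when `composed` is the identity. The helper names are hypothetical and may not match how the pass implements `areNullifyingTransposes`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Applies a TOSA-style permutation to a shape: result[i] = shape[perms[i]].
std::vector<int64_t> applyPermutation(const std::vector<int64_t> &shape,
                                      const std::vector<int32_t> &perms) {
  assert(shape.size() == perms.size());
  std::vector<int64_t> result;
  result.reserve(perms.size());
  for (int32_t p : perms)
    result.push_back(shape[p]);
  return result;
}

// Composes two transposes applied in sequence (first `perms1`, then
// `perms2`) into one permutation: composed[i] = perms1[perms2[i]].
std::vector<int32_t> composePermutations(const std::vector<int32_t> &perms1,
                                         const std::vector<int32_t> &perms2) {
  assert(perms1.size() == perms2.size());
  std::vector<int32_t> composed;
  composed.reserve(perms2.size());
  for (int32_t p : perms2)
    composed.push_back(perms1[p]);
  return composed;
}

// Two transposes cancel exactly when their composition is the identity.
bool areNullifying(const std::vector<int32_t> &perms1,
                   const std::vector<int32_t> &perms2) {
  std::vector<int32_t> composed = composePermutations(perms1, perms2);
  for (std::size_t i = 0; i < composed.size(); ++i)
    if (composed[i] != static_cast<int32_t>(i))
      return false;
  return true;
}
```

For instance, the NCHW -> NHWC perms `{0, 2, 3, 1}` followed by the NHWC -> NCHW perms `{0, 3, 1, 2}` compose to `{0, 1, 2, 3}`, so such a pair is cancelled.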
We then ensure that we do not cause any duplication by "converting" the bubbled-up chain into its transposed form. To do so, we analyze the dependency fan-ins across all transposes with the same perms tensor to ensure that they have no uses outside this group; such uses would keep the old code section "live" so that canonicalization could not remove it.
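The replacement decision is only summarized here, and the pass's own `getGoodReplacements` is truncated in the diff below, so the following is a simplified, hypothetical model of the idea rather than the actual algorithm. Given candidate transposes that share a perms tensor, each paired with the set of ops its conversion would make dead, it repeatedly drops any candidate whose fan-in has a user outside the accepted group, until a fixed point is reached. `Candidate` and `selectReplacements` are invented names.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Hypothetical model: node ids stand in for operations.
struct Candidate {
  int transpose;       // downstream transpose id
  std::set<int> fanIn; // ops its conversion would make dead
};

// Returns the candidate transposes that can be replaced without any op of
// their fan-in staying alive because of a user outside the accepted group.
std::set<int> selectReplacements(const std::vector<Candidate> &candidates,
                                 const std::map<int, std::vector<int>> &users) {
  std::set<int> accepted;
  for (const Candidate &c : candidates) {
    assert(!c.fanIn.count(c.transpose)); // fan-in excludes the transpose
    accepted.insert(c.transpose);
  }

  bool changed = true;
  while (changed) { // shrink the accepted set until a fixed point
    changed = false;
    // Everything that becomes dead if all accepted transposes are replaced.
    std::set<int> group(accepted);
    for (const Candidate &c : candidates)
      if (accepted.count(c.transpose))
        group.insert(c.fanIn.begin(), c.fanIn.end());

    for (const Candidate &c : candidates) {
      if (!accepted.count(c.transpose))
        continue;
      for (int op : c.fanIn) {
        auto it = users.find(op);
        if (it == users.end())
          continue;
        bool escapes = false;
        for (int user : it->second)
          if (!group.count(user))
            escapes = true; // a use outside the group keeps `op` alive
        if (escapes) {
          accepted.erase(c.transpose);
          changed = true;
          break;
        }
      }
    }
  }
  return accepted;
}
```

Dropping one candidate shrinks the group, which can in turn invalidate another candidate; iterating to a fixed point captures that interdependence between fan-ins.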
--------------
Impact of Pass:
--------------
For the ResNet18 network, we reduce the number of transposes from 56 to 5 when the torch dense_resource artifacts are patched with dense attributes. Without that patch, we reduce it to 23, since we cannot fold those artifacts.
In the second case (56 -> 23), instruction count is reduced by exactly 33. There are 3 transposes that would be removed if we omitted the fan-in analysis; however, with fan-in analysis, we end up with ~15 fewer operations, due to the lack of duplication.
For ResNet50, the results are essentially identical.
For MobilenetV3, we reduce the number of transposes from 82 to 38 without taking care of upstream constants. After also taking care of constants, we reduce it to 20 transposes. The remaining transposes have uses outside of the fan-in cones. The pass alone (after --canonicalize is run on the initial network) is responsible for the removal of 48 of the transposes.
Because some constants are used directly in their NCHW form (without a transpose to NHWC) and also, separately, in a fan-in whose downstream transpose converts to NHWC, we end up with 7 additional constants; however, due to their small size, they add little to the memory footprint.
-----------
Future Work:
-----------
(1)
Evaluate tradeoffs with the duplication of ConstOp, especially across many downstream transposes with different perms, which can result in the same ConstOp being duplicated (but transposed) multiple times.
Observe tradeoffs between a lower memory footprint and potentially converting many fan-ins of downstream transposes with the same perms, which, if not converted, may affect the ability of other interdependent fan-ins to convert.
(2)
Restrict the propagation of transposes up their fan-in cone if one of the sources is a ReshapeOp for which the inserted TransposeOp would not qualify for the TransposeIsReshape canonicalization, which would otherwise allow the ReshapeOp/TransposeOp pair to be folded into a single ReshapeOp.
Observe how this restriction may hinder the conversion of other downstream transposes due to the fan-in cone analysis. Additionally, consider cases where multiple upstream transposes could be removed as a result of this, and trade that off against how many transposes would effectively be inserted if the ReshapeOp/TransposeOp pair can't be folded into a single ReshapeOp.
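For reference, the condition that item (2) relies on can be stated simply: a transpose is a pure reshape when it only moves size-1 dimensions, so the non-unit dimensions keep their relative order. The sketch assumes this is the criterion used by TOSA's TransposeIsReshape canonicalization; `transposeIsReshape` is a hypothetical standalone helper, not an MLIR function.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns true if transposing a tensor of `inputShape` by `perms` leaves the
// row-major element order unchanged, i.e. only size-1 dimensions move.
// Such a transpose could be expressed as a reshape instead.
bool transposeIsReshape(const std::vector<int64_t> &inputShape,
                        const std::vector<int32_t> &perms) {
  assert(inputShape.size() == perms.size());
  int64_t lastNonUnitDim = -1;
  for (int32_t p : perms) {
    if (inputShape[p] == 1)
      continue; // unit dims can move without reordering any data
    if (p < lastNonUnitDim)
      return false; // two non-unit dims swap their relative order
    lastNonUnitDim = p;
  }
  return true;
}
```

For example, transposing a `1x64x1x1` tensor with perms `{0, 2, 3, 1}` qualifies, while the same perms on a `1x64x56x56` tensor do not.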
(3)
Make the pass more general, beyond just allowing upstream transposes to be nullifying. For example,
transpose1 -> ... -> transpose2
where transpose2 composed with transpose1 does not cancel to the identity.
This can be done by propagating the downstream transpose up and inserting after transpose1, just like how it is done for reshape. However, in the case of chains like
transpose1 -> ... -> transpose2 -> ... -> transpose3
this could require running the current runOnOperation() function until we converge. This can be done by stopping when all transposes whose fan-ins we can successfully collect have, as the owner of their first operand, either another TransposeOp or a ReshapeOp, since those are what we propagate to and where we leave behind / insert another TransposeOp. Otherwise, we could potentially loop infinitely.
This additionally has the implication that we would not replace any transposes and instead we could have canonicalization handle that.
(4)
Add support for more instructions (for example, those that reduce along an axis) to be one of the intervening operations in the fan-in cones (other than those with the TosaElementwiseOperator trait).
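For reductions, one concrete piece of that work is remapping the axis. Assuming a rank-preserving reduction (TOSA reductions keep the reduced dimension with size 1), bubbling a transpose with `perms` above a reduction over `axis` means the reduction must run over the position that `axis` moves to, i.e. the `a'` with `perms[a'] == axis`. The helpers below are hypothetical illustrations, not part of the pass.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// inverse[perms[i]] = i: maps an input dimension to its position in the
// transposed tensor.
std::vector<int32_t> inversePermutation(const std::vector<int32_t> &perms) {
  std::vector<int32_t> inverse(perms.size(), -1);
  for (std::size_t i = 0; i < perms.size(); ++i) {
    assert(perms[i] >= 0 && static_cast<std::size_t>(perms[i]) < perms.size());
    inverse[perms[i]] = static_cast<int32_t>(i);
  }
  return inverse;
}

// Hypothetical helper: if a transpose with `perms` is moved above a
// rank-preserving reduction over `axis`, the reduction must instead run
// over the dimension that `axis` is moved to by the transpose.
int32_t remapReductionAxis(const std::vector<int32_t> &perms, int32_t axis) {
  return inversePermutation(perms)[axis];
}
```

With the NCHW -> NHWC perms `{0, 2, 3, 1}`, a reduction over the channel axis 1 becomes a reduction over axis 3.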
(5)
Support bubbling transposes up to the function's input parameters. This may not need extensive fan-in analysis, since there is no operation cost associated with an input that is also used elsewhere.
---
Patch is 87.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108260.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+10)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+6)
- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h (+2)
- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+16)
- (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp (+761)
- (added) mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-dynamic-shapes.mlir (+55)
- (added) mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-pipeline.mlir (+48)
- (added) mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-static-shapes.mlir (+552)
- (added) mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-unimplemented.mlir (+120)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..64bacd0e432fe5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -206,6 +206,15 @@ def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
input, paddings, pad_value);
}]>;
+//===----------------------------------------------------------------------===//
+// TOSA Operator Trait.
+//===----------------------------------------------------------------------===//
+
+// Permits broadcasting. Elementwise trait is too strict.
+def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
+ let cppNamespace = "mlir::OpTrait::tosa";
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Class.
//===----------------------------------------------------------------------===//
@@ -219,6 +228,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape,
+ TosaElementwiseOperator,
Pure])> {
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 7ed89bff474a2e..8122752a9f3e1a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -84,6 +84,12 @@ class MulOperandsAndResultElementType
}
};
+/// This class indicates that an op is tosa-elementwise (permits broadcasting,
+/// unlike Elementwise trait)
+template <typename ConcreteType>
+class TosaElementwiseOperator
+ : public TraitBase<ConcreteType, TosaElementwiseOperator> {};
+
} // namespace tosa
} // namespace OpTrait
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 1f9522b51a4cf5..c0913171f9b17f 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
#include "mlir/Pass/Pass.h"
@@ -48,6 +49,7 @@ std::unique_ptr<Pass> createTosaInferShapesPass();
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
std::unique_ptr<Pass> createTosaOptionalDecompositions();
+std::unique_ptr<Pass> createTosaRemoveRedundantTransposes();
struct ValidationOptions {
/// Validate if operations match for the given profile.
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index a0f670de20150f..66d046c7040b4f 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -126,4 +126,20 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
];
}
+def TosaRemoveRedundantTransposes : Pass<"tosa-remove-redundant-transposes", "func::FuncOp"> {
+ let summary = "Remove redundant transposes";
+ let description = [{
+ Pass that identifies and removes redundant tosa.TRANSPOSE operations.
+ It does so by traversing dependencies of tosa.TRANSPOSE operations until they terminate in either
+ tosa.RESHAPE, a nullifying tosa.TRANSPOSE, or a tosa.CONST. It then propagates the downstream
+ transform upward through the intervening operators if it is able and replaces the downstream tosa.TRANSPOSE.
+ Results are generally better when run after canonicalization and resolution of dynamic shapes.
+ Canonicalization is required for dead code elimination after the pass is run.
+ This pass has an important use-case in cleaning up the results of frameworks that introduce a lot
+ of data-layout transformations when legalizing to TOSA, a common one being transformations between NHWC and NCHW
+ layouts.
+ }];
+ let constructor = "tosa::createTosaRemoveRedundantTransposes()";
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index c78a74b874aff1..624038b9b38981 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaLayerwiseConstantFoldPass.cpp
TosaMakeBroadcastable.cpp
TosaOptionalDecompositions.cpp
+ TosaRemoveRedundantTransposes.cpp
TosaTypeConverters.cpp
TosaValidation.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp
new file mode 100644
index 00000000000000..06d7754e79a023
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp
@@ -0,0 +1,761 @@
+//===- TosaRemoveRedundantTransposes.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ----------
+// Motivation:
+// ----------
+
+// Some legalization pathways introduce redundant tosa.TRANSPOSE
+// operations that result in avoidable data movement. For example,
+// PyTorch -> TOSA contains a lot of unnecessary transposes due
+// to conversions between NCHW and NHWC.
+
+// We wish to remove all the ones that we can, since in general
+// it is possible to remove upwards of 90% of these transposes
+// in a provable manner.
+//
+// -------------------
+// High-Level Overview:
+// -------------------
+
+// The pass begins at a downstream transpose with some perms tensor.
+// It traverses the dependencies upward, accepting only TosaElementwise
+// operators. Dependencies must terminate in nullifying transposes (when
+// composed, they form the identity), reshapes, or consts.
+
+// Conceptually, we then "bubble up" the downstream transpose until
+// we hit the sources. For constants, we generate new constants, composed
+// with the downstream transpose. For nullifying transposes, we "cancel"
+// them. For reshapes, we generally cannot "bubble" through them, so we
+// insert the downstream transpose there.
+
+// We then ensure that we do not cause any duplication by replacing usages
+// of the downstream transpose with the converted value of the operand
+// that feeds into it (after this bubble-up process). We do this by analyzing
+// the dependency fan-ins across all transposes with the same perms tensor
+// in order to ensure that they do not have uses outside this group, which
+// would cause the old code section to remain "live", and not removed by
+// canonicalization.
+
+// --------------
+// Impact of Pass:
+// --------------
+
+// For the ResNet18 network, we are able to reduce it to 5 transposes, from
+// 56 -- with the patching of the torch dense_resource artifacts with dense
+// attributes. Otherwise, without that patch, we reduce to 23, since we cannot
+// fold those artifacts.
+
+// In the second case (56 -> 23), instruction count is reduced by exactly 33.
+// There are 3 transposes that would be removed if we omitted the fan-in
+// analysis; however, with fan-in analysis, we end up with ~15 fewer
+// operations, due to the lack of duplication.
+
+// For ResNet50, the results are essentially identical.
+
+// For MobilenetV3, we reduce the number of transposes from 82 to 38 without
+// taking care of upstream constants. After also taking care of constants, we
+// reduce it to 20 transposes. The remaining have a use elsewhere outside
+// of the fan-in cones. The pass alone (after --canonicalize is run on the
+// initial network), is responsible for the removal of 48 of the transposes.
+
+// Due to cases where a constant is used purely in its NCHW form without a
+// transpose to NHWC and also separately used in a place where the downstream
+// converts to NHWC, we do end up with 7 additional constants; however, due to
+// their small size, this has minimal memory footprint.
+
+// -----------
+// Future Work:
+// -----------
+
+// (1)
+
+// Evaluate tradeoffs with the duplication of ConstOp, especially
+// across many downstream transposes with different perms, which can result
+// in the same ConstOp being duplicated (but transposed) multiple times.
+
+// Observe tradeoffs between a lower memory footprint and potentially
+// converting many fan-ins of downstream transposes with the same perms,
+// which if not converted may affect ability of other inter-dependent fan-in
+// to convert.
+
+// (2)
+
+// Restrict the propagation of transposes up their fan-in cone if one
+// of the sources is a ReshapeOp for which the inserted TransposeOp would
+// not be a TransposeOp that lends itself to the TransposeIsReshape
+// Canonicalization, which permits them to be folded to a single ReshapeOp.
+
+// Observe impact on how this restriction may be detrimental to the
+// conversion of other downstream transpose conversions due to the
+// fan-in cone analysis. Additionally, consider cases where there
+// may be multiple upstream transposes that could be removed as a
+// result of this -- and trade that off with how many you would
+// effectively insert if the ReshapeOp/TransposeOp can't be folded
+// to a single ReshapeOp.
+
+// (3)
+
+// Make the pass more general, beyond just allowing upstream transposes
+// to be nullifying. For example,
+
+// transpose1 -> ... -> transpose2
+
+// where transpose2(transpose1) do not cancel to identity.
+
+// This can be done by propagating the downstream transpose up
+// and inserting after transpose1, just like how it is done for
+// reshape. However, in the case of chains like
+
+// transpose1 -> ... -> transpose2 -> ... -> transpose3
+
+// this could require running the current runOnOperation() function
+// until we converge. This can be done by stopping when all transposes
+// that we can successfully collect the fan-ins of have the owner
+// of their first operand being either another TransposeOp or a
+// ReshapeOp, since those are what we propagate to and where we leave
+// behind / insert another TransposeOp. Otherwise, we could potentially
+// loop infinitely.
+
+// This additionally has the implication that we would not replace any
+// transposes and instead we could have canonicalization handle that.
+
+// (4)
+
+// Add support for more instructions (for example, those that reduce
+// along an axis) to be one of the intervening operations in the
+// fan-in cones (other than those with TosaElementwiseOperator trait).
+
+// (5)
+
+// Support bubbling transposes up to the function's input parameters. This
+// may not need extensive fan-in analysis, since there is no operation cost
+// associated with an input that is also used elsewhere.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <memory>
+#include <set>
+#include <stack>
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAREMOVEREDUNDANTTRANSPOSES
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TOSA Remove Redundant Transposes Pass.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct TosaRemoveRedundantTransposes final
+ : public tosa::impl::TosaRemoveRedundantTransposesBase<
+ TosaRemoveRedundantTransposes> {
+ void runOnOperation() override;
+
+private:
+ // This will collect all the data dependencies for the given Operation
+ // up to and including ConstOp, ReshapeOp, and TransposeOp.
+ bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
+ bool convertDependentOps(SetVector<Operation *> &dependentOps,
+ DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter,
+ ArrayRef<int32_t> downstreamPerms);
+
+ // Checks if the two permutations, when applied consecutively, result
+ // in the identity.
+ bool areNullifyingTransposes(ArrayRef<int32_t> perms1,
+ ArrayRef<int32_t> perms2);
+
+ // This is meant to apply to operations with the TosaElementwiseOperator
+ // trait.
+ std::optional<Value>
+ buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // This updates valuesMap when we encounter another TransposeOp as a
+ // dependency of the downstream one. Given %0 = tosa.transpose %arg0 and
+ // %1 = tosa.transpose %0, it applies to %0 when tracking back from %1.
+ std::optional<Value>
+ buildMappedToValue(tosa::TransposeOp transposeOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // Inserts the downstream TransposeOp after the ReshapeOp, since we generally
+ // cannot propagate through it.
+ std::optional<Value>
+ buildMappedToValue(tosa::ReshapeOp reshapeOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // We may have something like:
+ // %0 = tosa.const
+ // %1 = tosa.transpose
+ // %2 = tosa.add %0, %1
+ // %3 = tosa.transpose %2
+ // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
+ // in MobilenetV3.
+ std::optional<Value>
+ buildMappedToValue(tosa::ConstOp constOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // Checks which TransposeOp we should "replace", turning their converted
+ // chains of ops, through which they were propagated, "live", and the old code
+ // "dead." Attempts to avoid doing so when doing so would result in the old
+ // code staying "live," resulting in duplication. Relies on --canonicalize to
+ // remove the dead code that results from performing said replacement.
+ std::set<tosa::TransposeOp> getGoodReplacements(
+ ArrayRef<int32_t> perms,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Helper function for getGoodReplacements to check if some TransposeOp's
+ // dependencies are OK.
+ bool dependenciesAreValid(
+ ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+ std::set<tosa::TransposeOp> &validTransposes,
+ std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Applies perms to the DenseElementsAttr.
+ // If it returns std::nullopt, it also triggers pass failure, since verifier
+ // guarantees from TOSA are not in place (and otherwise, if used elsewhere
+ // it should fail).
+ // This is a basic API and may benefit from refactor into the core MLIR APIs.
+ std::optional<DenseElementsAttr>
+ transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
+};
+
+std::optional<DenseElementsAttr>
+TosaRemoveRedundantTransposes::transposeDenseAttribute(
+ DenseElementsAttr input, ArrayRef<int32_t> perms) {
+ RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
+ RankedTensorType newType = RankedTensorType::get(
+ tosa::applyTOSAPermutation(oldType.getShape(), perms),
+ oldType.getElementType());
+ size_t rank = oldType.getRank();
+
+ if (input.isSplat())
+ return input.reshape(newType);
+ // Guaranteed by the TransposeOp verifier and by TOSA disallowing tensors
+ // with a dimension of size 0.
+ // If these guarantees do not hold, something is very wrong.
+ if (rank <= 0 || oldType.getNumElements() <= 0 || perms.size() != rank) {
+ signalPassFailure();
+ return std::nullopt;
+ }
+
+ // The algorithm is approximately as follows:
+ // input: perms, input flat array, input tensor type
+ // (1/2) Determine the strides of the input/output as if they were laid
+ // out in row-major order. (3) Reorder the input strides into the same
+ // index order in which the output is written. (4) Process dimension by
+ // dimension. Example: perms 2, 0, 1; input 2x3x4; output 4x2x3.
+ // For i < 4, j < 2, k < 3: output[i][j][k] = input[j][k][i], i.e.
+ // output[6i + 3j + k] = input[12j + 4k + i]. Reordering the input
+ // strides as input[i + 12j + 4k] lets us process the output
+ // layer by layer.
+
+ // Step 1/2: Strides for the input. Output strides are not needed, since the
+ // output is produced in row-major order with push_back.
+
+ SmallVector<int64_t> originalInputStrides(rank);
+ originalInputStrides[rank - 1] = 1;
+ // index with int64_t to avoid overflow
+ for (int64_t i = rank - 2; i >= 0; i--)
+ originalInputStrides[i] =
+ originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
+
+ // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
+ // output which is done in row-major order.
+
+ SmallVector<int64_t> newInputStrides;
+ newInputStrides.reserve(rank);
+ for (int32_t v : perms)
+ newInputStrides.push_back(originalInputStrides[v]);
+
+ // Step 4: Write out the transposed "flat array" dimension by dimension.
+
+ auto inputArray = input.getValues<Attribute>();
+ SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
+ for (size_t i = 0; i < rank; i++)
+ boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
+
+ SmallVector<Attribute> resultArray;
+ resultArray.reserve(inputArray.size());
+
+ std::function<void(int64_t,
+ SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
+ processTransposeDim = [&](auto accumulatedIndex, auto it) {
+ if (it == boundsAndStrides.end()) {
+ resultArray.push_back(inputArray[accumulatedIndex]);
+ return;
+ }
+
+ for (int64_t i = 0; i < it->first; i++) {
+ int64_t j = accumulatedIndex + i * it->second;
+ processTransposeDim(j, it + 1);
+ }
+ };
+
+ processTransposeDim(0, boundsAndStrides.begin());
+
+ return DenseElementsAttr::get(newType, resultArray);
+}
+
+// The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
+// as the sources of the data dependencies, and TosaElementwiseOperator
+// after that, if the function returns true.
+bool TosaRemoveRedundantTransposes::collectFanIn(
+ Operation *op, SetVector<Operation *> &collected) {
+ // Can occur if defined through the parameter to a func.func.
+ if (!op)
+ return false;
+
+ if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
+ return false;
+
+ // Prevent extra work if already seen.
+ if (collected.contains(op))
+ return true;
+
+ // Throw it out so later don't have to deal with this.
+ if (op->getNumResults() != 1 ||
+ !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
+ return false;
+
+ // We don't wish to traverse up a ReshapeOp,
+ // since generally we can't propagate a TransposeOp through it.
+ // TransposeOp, ReshapeOp, ConstOp will have no in-edges in the data
+ // dependency graph we construct for the downstream TransposeOp.
+ if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
+ !llvm::isa<tosa::ConstOp>(op)) {
+
+ if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+ return false;
+
+ for (Value operand : op->getOperands()) {
+
+ if (!collectFanIn(operand.getDefiningOp(), collected))
+ return false;
+ }
+ }
+
+ // Insert in topological order.
+ collected.insert(op);
+
+ return true;
+}
+
+// Assumes that, due to TransposeOp verification, perms arrays are
+// permutations of 0 .. perms.size() - 1.
+bool TosaRemoveRedundantTransposes::areNullifyingTransposes(
+ ...
[truncated]
``````````
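To make the stride reasoning in `transposeDenseAttribute` easier to check, here is a standalone C++ sketch of the same four steps on a flat row-major `std::vector`, independent of MLIR. `transposeFlat` is a hypothetical name; it assumes a non-empty shape whose element count matches the buffer, and the perms convention assumed for `tosa::applyTOSAPermutation` (output dimension `i` is input dimension `perms[i]`).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Transposes a row-major buffer of `shape` by TOSA-style `perms`
// (output dim i is input dim perms[i]) using stride reordering.
template <typename T>
std::vector<T> transposeFlat(const std::vector<T> &input,
                             const std::vector<int64_t> &shape,
                             const std::vector<int32_t> &perms) {
  const std::size_t rank = shape.size();
  assert(rank > 0 && perms.size() == rank);

  // Steps 1/2: row-major strides of the input.
  std::vector<int64_t> inputStrides(rank, 1);
  for (std::size_t i = rank - 1; i > 0; --i)
    inputStrides[i - 1] = inputStrides[i] * shape[i];
  assert(static_cast<int64_t>(input.size()) == inputStrides[0] * shape[0]);

  // Step 3: reorder input strides (and sizes) into the output's index order.
  std::vector<int64_t> outShape(rank), permutedStrides(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    outShape[i] = shape[perms[i]];
    permutedStrides[i] = inputStrides[perms[i]];
  }

  // Step 4: walk the output in row-major order, gathering from the input.
  std::vector<T> result;
  result.reserve(input.size());
  std::function<void(std::size_t, int64_t)> walk = [&](std::size_t dim,
                                                       int64_t offset) {
    if (dim == rank) {
      result.push_back(input[static_cast<std::size_t>(offset)]);
      return;
    }
    for (int64_t i = 0; i < outShape[dim]; ++i)
      walk(dim + 1, offset + i * permutedStrides[dim]);
  };
  walk(0, 0);
  return result;
}
```

With perms `{2, 0, 1}` on a `2x3x4` input holding `0..23`, output element `[1][1][2]` (flat index 11) is input element `[1][2][1]`, i.e. value 21, matching `output[6i + 3j + k] = input[12j + 4k + i]`.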
</details>
https://github.com/llvm/llvm-project/pull/108260