[Mlir-commits] [mlir] [MLIR][TOSA] Add --tosa-remove-redundant-transposes pass (PR #108260)
Jacques Pienaar
llvmlistbot at llvm.org
Fri Sep 13 11:08:06 PDT 2024
================
@@ -0,0 +1,731 @@
+//===- TosaRemoveRedundantTransposes.cpp ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// ----------
+// Motivation:
+// ----------
+
+// Some legalization pathways introduce redundant tosa.TRANSPOSE
+// operations that result in avoidable data movement. For example,
+// the PyTorch -> TOSA pathway contains many unnecessary transposes due
+// to conversions between NCHW and NHWC layouts.
+
+// We wish to remove as many of these as we can, since in general
+// the overwhelming majority can be removed.
+
+// -------------------
+// High-Level Overview:
+// -------------------
+
+// The pass begins at a downstream transpose with some perms tensor.
+// It traverses its dependencies upward, accepting only TosaElementwise
+// operators. The dependencies must terminate in nullifying transposes
+// (transposes that compose with the downstream one to form the identity),
+// reshapes we can fold the transpose into, or consts.
+
+// Conceptually, we then "bubble up" the downstream transpose until
+// we hit the sources. For constants, we generate a new constant,
+// composed with the downstream transpose. For nullifying transposes,
+// we "cancel" them. For reshapes, we fold the transpose into them.
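+
+// For example (an illustrative sketch, with perms shown inline and
+// operand details elided, as in the examples further below):
+//
+//   %0 = tosa.transpose %arg0 by perms [1, 0]
+//   %1 = tosa.abs %0
+//   %2 = tosa.transpose %1 by perms [1, 0]
+//
+// Bubbling %2 up through the elementwise %1 cancels it against the
+// nullifying %0, so %2 is replaced by the converted chain
+// "tosa.abs %arg0".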
+
+// We then ensure that we do not cause any duplication when "converting"
+// the chain we bubbled up into its transposed form. We do this by
+// analyzing the dependency fan-ins across all transposes with the same
+// perms tensor, ensuring that they have no uses outside this group;
+// such uses would keep the old code section "live", so it would not be
+// removed by DCE.
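+
+// For example (continuing the sketch above), if %0 or %1 had a use
+// outside the chain, replacing %2 alone would keep the old chain "live",
+// and the computation would be duplicated in both forms.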
+
+// We then perform a simple one-pass DCE, so no canonicalization is necessary.
+
+// --------------
+// Impact of Pass:
+// --------------
+
+// We note that up to 98.3% of transpose data movement and 98.0%
+// of transposes can be removed from MobilenetV3 and ResNet networks.
+
+// -----------
+// Future Work:
+// -----------
+
+// (1)
+
+// Evaluate tradeoffs with the duplication of ConstOp, especially
+// across many downstream transposes with different perms, which can result
+// in the same ConstOp being duplicated (but transposed) multiple times.
+
+// Observe tradeoffs between a lower memory footprint and potentially
+// converting many fan-ins of downstream transposes with the same perms,
+// which, if not converted, may affect the ability of other
+// inter-dependent fan-ins to convert.
+
+// (2)
+
+// Expand the class of foldable upstream ReshapeOp we permit beyond
+// N -> 1x1x...x1xNx1x...x1x1.
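+
+// For example, a tosa.reshape from tensor<8xf32> to tensor<1x1x8x1xf32>
+// fits this class, while one to tensor<2x4xf32> does not.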
+
+// (3)
+
+// Make the pass more general, beyond just allowing upstream transposes
+// to be nullifying. For example,
+
+// transpose1 -> ... -> transpose2
+
+// where transpose2(transpose1) does not cancel to the identity.
+
+// This can be done by propagating the downstream transpose up
+// and inserting it after transpose1, just as is done for
+// reshape. However, in the case of chains like
+
+// transpose1 -> ... -> transpose2 -> ... -> transpose3
+
+// this could require running the current runOnOperation() function
+// until we converge. We can stop once, for every transpose whose fan-in
+// we can successfully collect, the owner of its first operand is either
+// another TransposeOp or a ReshapeOp, since those are what we propagate
+// to and where we leave behind or insert another TransposeOp. Otherwise,
+// we could potentially loop infinitely.
+
+// Folding of the transposes is then necessary.
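+
+// For example (illustrative), propagating a downstream transpose with
+// perms_b up to just after transpose1 with perms_a yields two adjacent
+// transposes, which fold into one whose perms_c[d] = perms_a[perms_b[d]].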
+
+// (4)
+
+// Add support for more instructions (for example, those that reduce
+// along an axis) to be one of the intervening operations in the
+// fan-in cones (other than those with the TosaElementwiseOperator trait).
+
+// (5)
+
+// Support bubbling transposes up to the input parameter. This may not
+// need extensive fan-in analysis, since no operation cost is associated
+// with the input parameter if it is used elsewhere.
+
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <memory>
+#include <set>
+#include <stack>
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAREMOVEREDUNDANTTRANSPOSES
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+//===----------------------------------------------------------------------===//
+// TOSA Remove Redundant Transposes Pass.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct TosaRemoveRedundantTransposes final
+ : public tosa::impl::TosaRemoveRedundantTransposesBase<
+ TosaRemoveRedundantTransposes> {
+ void runOnOperation() override;
+
+private:
+ // This will collect all the data dependencies for the given Operation
+ // up to and including ConstOp, ReshapeOp, and TransposeOp.
+ bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
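+
+ // Converts the collected dependent ops into their transposed form with
+ // respect to downstreamPerms, recording the mapping from original to
+ // converted values in valuesMap.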
+ bool convertDependentOps(SetVector<Operation *> &dependentOps,
+ DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter,
+ ArrayRef<int32_t> downstreamPerms);
+
+ // Checks if the two permutations, when applied consecutively, result
+ // in the identity.
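+ // For example, perms1 = [1, 2, 0] and perms2 = [2, 0, 1] are nullifying:
+ // each is the inverse of the other, and composing them yields the
+ // identity permutation [0, 1, 2].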
+ bool areNullifyingTransposes(ArrayRef<int32_t> perms1,
+ ArrayRef<int32_t> perms2);
+
+ // This is meant to apply to operations with the TosaElementwiseOperator
+ // trait.
+ std::optional<Value>
+ buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // This updates valuesMap when we encounter another TransposeOp as a
+ // dependency of the downstream one:
+ //   %0 = tosa.transpose %arg0 <- applies to this
+ //   %1 = tosa.transpose %0    <- when tracking back from this
+ std::optional<Value>
+ buildMappedToValue(TransposeOp transposeOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // Inserts the downstream TransposeOp after the ReshapeOp, since we generally
+ // cannot propagate through it.
+ std::optional<Value>
+ buildMappedToValue(ReshapeOp reshapeOp,
+ const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // We may have something like:
+ // %0 = tosa.const
+ // %1 = tosa.transpose
+ // %2 = tosa.add %0, %1
+ // %3 = tosa.transpose %2
+ // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
+ // in MobilenetV3.
+ std::optional<Value>
+ buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
+ IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+ // Checks which TransposeOps we should "replace": replacing one turns the
+ // converted chain of ops, through which it was propagated, "live", and
+ // the old code "dead." We avoid replacements that would leave the old
+ // code "live," since that would result in duplication.
+ std::set<TransposeOp> getGoodReplacements(
+ ArrayRef<int32_t> perms,
+ std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Helper function for getGoodReplacements to check if some TransposeOp's
+ // dependencies are OK.
+ bool dependenciesAreValid(
+ ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+ std::set<TransposeOp> &validTransposes,
+ std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
+ &transposeInfo);
+
+ // Applies perms to the DenseElementsAttr.
+ // If it returns std::nullopt, it also signals pass failure, since this
+ // means the verifier guarantees from TOSA are not in place (and if the
+ // attribute were used elsewhere, it should fail there as well).
+ // This is a basic API and may benefit from refactoring into the core
+ // MLIR APIs.
+ std::optional<DenseElementsAttr>
+ transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
+};
+
+std::optional<DenseElementsAttr>
+TosaRemoveRedundantTransposes::transposeDenseAttribute(
+ DenseElementsAttr input, ArrayRef<int32_t> perms) {
+ RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
+ RankedTensorType newType =
+ RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
+ oldType.getElementType());
+ size_t rank = oldType.getRank();
+
+ if (input.isSplat())
+ return input.reshape(newType);
+ // These conditions are asserted by the TransposeOp verifier and by TOSA
+ // disallowing tensors with a dimension of 0, so if they do not hold,
+ // something is very wrong.
+ if (rank <= 0 || oldType.getNumElements() <= 0 || perms.size() != rank) {
+ signalPassFailure();
+ return std::nullopt;
+ }
+
+ // The algorithm is approximately as follows:
+ //   input: perms, input flat array, input tensor type
+ //   (1/2) Determine the strides of the input/output as if they were
+ //         laid out in row-major order.
+ //   (3)   Adjust the input strides to be in the same order of indices
+ //         as the output is written.
+ //   (4)   Process dimension by dimension.
+ //
+ // Example: perms = (2, 0, 1); input = 2x3x4; output = 4x2x3.
+ //   For i in [0, 4), j in [0, 2), k in [0, 3):
+ //     output[i][j][k]     = input[j][k][i]
+ //     output[6i + 3j + k] = input[12j + 4k + i]
+ //   We adjust the input strides so that it is indexed as
+ //   input[i + 12j + 4k], which lets us process the output
+ //   layer-by-layer.
+
+ // Step 1/2: Compute the strides of the input. We can ignore the output
+ // strides, since the output is written in row-major order and we can
+ // just push_back.
+
+ SmallVector<int64_t> originalInputStrides(rank);
+ originalInputStrides[rank - 1] = 1;
+ // Index with int64_t to avoid unsigned wraparound when decrementing.
+ for (int64_t i = rank - 2; i >= 0; i--)
+ originalInputStrides[i] =
+ originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
+
+ // Step 3: Permute the input strides to use the same indexing
+ // (i, j, k, ...) as the output, which is written in row-major order.
+
+ SmallVector<int64_t> newInputStrides;
+ newInputStrides.reserve(rank);
+ for (int32_t v : perms)
+ newInputStrides.push_back(originalInputStrides[v]);
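+ // For example, with perms (2, 0, 1) and a 2x3x4 input,
+ // originalInputStrides = (12, 4, 1) and newInputStrides = (1, 12, 4),
+ // matching input[i + 12j + 4k] from the example above.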
+
+ // Step 4: Write out the transposed "flat array" dimension by dimension.
+
+ auto inputArray = input.getValues<Attribute>();
+ SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
+ for (size_t i = 0; i < rank; i++)
+ boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
+
+ SmallVector<Attribute> resultArray;
+ resultArray.reserve(inputArray.size());
+
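+ // Recurse over the output dimensions in row-major order, accumulating
+ // the flat input index from the permuted strides; at the innermost
+ // level, copy the indexed input element into the result array.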
+ std::function<void(int64_t,
+ SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
+ processTransposeDim = [&](auto accumulatedIndex, auto it) {
+ if (it == boundsAndStrides.end()) {
+ resultArray.push_back(inputArray[accumulatedIndex]);
+ return;
+ }
+
+ for (int64_t i = 0; i < it->first; i++) {
+ int64_t j = accumulatedIndex + i * it->second;
+ processTransposeDim(j, it + 1);
+ }
+ };
+
+ processTransposeDim(0, boundsAndStrides.begin());
+
+ return DenseElementsAttr::get(newType, resultArray);
+}
+
+// If this function returns true, the SetVector contains only ConstOp,
+// ReshapeOp, and TransposeOp as the sources of the data dependencies,
+// with TosaElementwiseOperator ops after that.
+bool TosaRemoveRedundantTransposes::collectFanIn(
+ Operation *op, SetVector<Operation *> &collected) {
+ // This can occur if the value is defined through a parameter of a
+ // func.func.
+ if (!op)
+ return false;
+
+ if (!llvm::isa_and_present<TosaDialect>(op->getDialect()))
+ return false;
+
+ // Prevent extra work if already seen.
+ if (collected.contains(op))
+ return true;
+
+ // Throw it out now so we don't have to deal with it later.
+ if (op->getNumResults() != 1 ||
+ !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
+ return false;
+
+ // We don't wish to traverse up a ReshapeOp,
+ // since generally we can't propagate a TransposeOp through it.
+ // TransposeOp, ReshapeOp, and ConstOp will have no in-edges in the data
+ // dependency graph we construct for the downstream TransposeOp.
+ if (!llvm::isa<TransposeOp>(op) && !llvm::isa<ReshapeOp>(op) &&
+ !llvm::isa<ConstOp>(op)) {
+
+ if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+ return false;
+
+ for (Value operand : op->getOperands()) {
+
+ if (!collectFanIn(operand.getDefiningOp(), collected))
+ return false;
+ }
+ }
+
+ // Insert in topological order.
+ collected.insert(op);
+
+ return true;
+}
+
+// We assume that, due to TransposeOp verification, the perms arrays are
+// permutations of 0 to perms.size() - 1.
+bool TosaRemoveRedundantTransposes::areNullifyingTransposes(
----------------
jpienaar wrote:
Involution is the name of the trait that is used elsewhere to refer to these.
https://github.com/llvm/llvm-project/pull/108260