[Mlir-commits] [mlir] [MLIR][TOSA] Add --tosa-remove-redundant-transposes pass (PR #108260)

Suraj Sudhir llvmlistbot at llvm.org
Fri Sep 13 07:20:32 PDT 2024

@@ -0,0 +1,733 @@
+//===- TosaRemoveRedundantTransposes.cpp
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// ----------
+// Motivation:
+// ----------
+// Some legalization pathways introduce redundant tosa.TRANSPOSE
+// operations that result in avoidable data movement. For example,
+// PyTorch -> TOSA contains a lot of unnecessary transposes due
+// to conversions between NCHW and NHWC.
+// We wish to remove all the ones that we can, since in general
+// it is possible to remove the overwhelming majority.
+// -------------------
+// High-Level Overview:
+// -------------------
+// The pass begins at a downstream transpose with some perms tensor.
+// It traverses the dependencies upward, accepting only TosaElementwise
+// operators. Dependencies must terminate in nullifying transposes (when
+// composed, they form the identity), reshapes we can fold the transpose into,
+// or consts.
+// Conceptually, we then "bubble up" the downstream transpose until
+// we hit the sources. For constants, we generate a new constant, composed
+// with the downstream transpose. For nullifying transposes, we "cancel"
+// them. For reshapes, we fold the transpose into them.
+// We then ensure that we do not cause any duplication by "converting"
+// this chain we bubbled-up into its transposed form. We do this by analyzing
+// the dependency fan-ins across all transposes with the same perms tensor
+// in order to ensure that they do not have uses outside this group, which
+// would cause the old code section to remain "live" and not be removed by
+// DCE.
+// We then perform a simple one-pass DCE, so no canonicalization is necessary.
+// --------------
+// Impact of Pass:
+// --------------
+// We note that up to 98.3% of transpose data movement and 98.0%
+// of transposes can be removed from MobilenetV3 and ResNet networks.
+// -----------
+// Future Work:
+// -----------
+// (1)
+// Evaluate tradeoffs with the duplication of ConstOp, especially
+// across many downstream transposes with different perms, which can result
+// in the same ConstOp being duplicated (but transposed) multiple times.
+// Observe tradeoffs between a lower memory footprint and potentially
+// converting many fan-ins of downstream transposes with the same perms,
+// which, if not converted, may affect the ability of other inter-dependent
+// fan-ins to convert.
+// (2)
+// Expand the class of foldable upstream ReshapeOp we permit beyond
+// N -> 1x1x...x1xNx1x...x1x1.
+// (3)
+// Make the pass more general, beyond just allowing upstream transposes
+// to be nullifying. For example,
+// transpose1 -> ... -> transpose2
+// where transpose2(transpose1) does not cancel to identity.
+// This can be done by propagating the downstream transpose up
+// and inserting after transpose1, just like how it is done for
+// reshape. However, in the case of chains like
+// transpose1 -> ... -> transpose2 -> ... -> transpose3
+// this could require running the current runOnOperation() function
+// until we converge. This can be done by stopping when all transposes
+// that we can successfully collect the fan-ins of have the owner
+// of their first operand being either another TransposeOp or a
+// ReshapeOp, since those are what we propagate to and where we leave
+// behind / insert another TransposeOp. Otherwise, we could potentially
+// loop forever.
+// Folding of the transposes is then necessary.
+// (4)
+// Add support for more instructions (for example, those that reduce
+// along an axis) to be one of the intervening operations in the
+// fan-in cones (other than those with TosaElementwiseOperator trait).
+// (5)
+// Support bubbling transposes up to the input parameter. May not
+// need extensive fan-in analysis as no operation cost associated
+// if used elsewhere.
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <functional>
+#include <memory>
+#include <set>
+#include <stack>
+namespace mlir {
+namespace tosa {
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+using namespace mlir;
+// TOSA Remove Redundant Transposes Pass.
+namespace {
+struct TosaRemoveRedundantTransposes final
+    : public tosa::impl::TosaRemoveRedundantTransposesBase<
+          TosaRemoveRedundantTransposes> {
+  void runOnOperation() override;
+  // This will collect all the data dependencies for the given Operation
+  // up to and including ConstOp, ReshapeOp, and TransposeOp.
+  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
+  bool convertDependentOps(SetVector<Operation *> &dependentOps,
+                           DenseMap<Value, Value> &valuesMap,
+                           IRRewriter &rewriter,
+                           ArrayRef<int32_t> downstreamPerms);
+  // Checks if the two permutations, when applied consecutively, result
+  // in the identity.
+  bool areNullifyingTransposes(ArrayRef<int32_t> perms1,
+                               ArrayRef<int32_t> perms2);
+  // This is meant to apply to operations with the TosaElementwiseOperator
+  // trait.
+  std::optional<Value>
+  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
+                     IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+  // This updates valuesMap when we encounter another TransposeOp as a
+  // dependency of the downstream one.
+  //   %0 = tosa.transpose %arg0 <- applies to this
+  //   %1 = tosa.transpose %0    <- when tracking back from this
+  std::optional<Value>
+  buildMappedToValue(tosa::TransposeOp transposeOp,
+                     const DenseMap<Value, Value> &valuesMap,
+                     IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+  // Inserts the downstream TransposeOp after the ReshapeOp, since we generally
+  // cannot propagate through it.
+  std::optional<Value>
+  buildMappedToValue(tosa::ReshapeOp reshapeOp,
+                     const DenseMap<Value, Value> &valuesMap,
+                     IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+  // We may have something like:
+  // %0 = tosa.const
+  // %1 = tosa.transpose
+  // %2 = tosa.add %0, %1
+  // %3 = tosa.transpose %2
+  // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
+  // in MobilenetV3.
+  std::optional<Value>
+  buildMappedToValue(tosa::ConstOp constOp,
+                     const DenseMap<Value, Value> &valuesMap,
+                     IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+  // Checks which TransposeOp we should "replace", turning their converted
+  // chains of ops, through which they were propagated, "live", and the old code
+  // "dead." Attempts to avoid doing so when doing so would result in the old
+  // code staying "live," resulting in duplication.
+  std::set<tosa::TransposeOp> getGoodReplacements(
+      ArrayRef<int32_t> perms,
+      std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+          &transposeInfo);
+  // Helper function for getGoodReplacements to check if some TransposeOp's
+  // dependencies are OK.
+  bool dependenciesAreValid(
+      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+      std::set<tosa::TransposeOp> &validTransposes,
+      std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+          &transposeInfo);
+  // Applies perms to the DenseElementsAttr.
+  // If it returns std::nullopt, it also triggers pass failure, since verifier
+  // guarantees from TOSA are not in place (and otherwise, if used elsewhere
+  // it should fail).
+  // This is a basic API and may benefit from refactor into the core MLIR APIs.
+  std::optional<DenseElementsAttr>
+  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
+};
+std::optional<DenseElementsAttr>
+TosaRemoveRedundantTransposes::transposeDenseAttribute(
+    DenseElementsAttr input, ArrayRef<int32_t> perms) {
+  RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
+  RankedTensorType newType = RankedTensorType::get(
+      tosa::applyTOSAPermutation(oldType.getShape(), perms),
+      oldType.getElementType());
+  size_t rank = oldType.getRank();
+  if (input.isSplat())
+    return input.reshape(newType);
+  // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
+  // 0.
+  // If not in place, something is very wrong.
+  if (rank <= 0 || oldType.getNumElements() <= 0 || perms.size() != rank) {
+    signalPassFailure();
+    return std::nullopt;
+  }
+  // The algorithm is approximately as follows:
+  // input: perms, input flat array, input tensor type
+  // (1/2) determine the strides of input/output if
+  // they were strided in row-major order.
+  // (3) adjust the strides for the input to be in the same order of indices
+  // as the output is written.
+  // (4) process dimension by dimension.
+  // example: perms 2, 0, 1; input 2x3x4; output 4x2x3
+  // for i ... 4, j ... 2, k ... 3:
+  //   output[i][j][k] = input[j][k][i]
+  //   output[6i + 3j + k] = input[12j + 4k + i]
+  // and we adjust input strides to be as input[i + 12j + 4k] so we may
+  // process layer-by-layer.
+  // Step 1/2: Strides for input. We ignore output since row-major and can just
+  // push_back.
+  SmallVector<int64_t> originalInputStrides(rank);
+  originalInputStrides[rank - 1] = 1;
+  // Index with a signed int64_t: rank is unsigned, and i must be able to
+  // reach -1 for the loop to terminate.
+  for (int64_t i = rank - 2; i >= 0; i--)
+    originalInputStrides[i] =
+        originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
+  // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
+  // output which is done in row-major order.
+  SmallVector<int64_t> newInputStrides;
+  newInputStrides.reserve(rank);
+  for (int32_t v : perms)
+    newInputStrides.push_back(originalInputStrides[v]);
+  // Step 4: Write out the transposed "flat array" dimension by dimension.
+  auto inputArray = input.getValues<Attribute>();
+  SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
+  for (size_t i = 0; i < rank; i++)
+    boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
+  SmallVector<Attribute> resultArray;
+  resultArray.reserve(inputArray.size());
+  std::function<void(int64_t,
+                     SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
+      processTransposeDim = [&](auto accumulatedIndex, auto it) {
+        if (it == boundsAndStrides.end()) {
+          resultArray.push_back(inputArray[accumulatedIndex]);
+          return;
+        }
+        for (int64_t i = 0; i < it->first; i++) {
+          int64_t j = accumulatedIndex + i * it->second;
+          processTransposeDim(j, it + 1);
+        }
+      };
+  processTransposeDim(0, boundsAndStrides.begin());
+  return DenseElementsAttr::get(newType, resultArray);
+}
+// The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
+// as the sources of the data dependencies, and TosaElementwiseOperator
+// after that, if the function returns true.
+bool TosaRemoveRedundantTransposes::collectFanIn(
+    Operation *op, SetVector<Operation *> &collected) {
+  // Can occur if defined through the parameter to a func.func.
+  if (!op)
+    return false;
+  if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
+    return false;
+  // Prevent extra work if already seen.
+  if (collected.contains(op))
+    return true;
+  // Reject it now so that we don't have to handle it later.
+  if (op->getNumResults() != 1 ||
+      !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
+    return false;
+  // We don't wish to traverse up a ReshapeOp,
+  // since generally we can't propagate a TransposeOp through it.
+  // TransposeOp, ReshapeOp, ConstOp will have no in-edges in the data
+  // dependency graph we construct for the downstream TransposeOp.
+  if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
+      !llvm::isa<tosa::ConstOp>(op)) {
+    if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+      return false;
+    for (Value operand : op->getOperands()) {
+      if (!collectFanIn(operand.getDefiningOp(), collected))
+        return false;
+    }
+  }
+  // Insert in topological order.
+  collected.insert(op);
+  return true;
+}
+// We assume that, due to the TransposeOp verifier, perms arrays are
+// permutations of 0 .. perms.size() - 1.
+bool TosaRemoveRedundantTransposes::areNullifyingTransposes(
+    ArrayRef<int32_t> perms1, ArrayRef<int32_t> perms2) {
+  if (perms1.size() != perms2.size())
+    return false;
+  for (int32_t i = 0; i < static_cast<int32_t>(perms1.size()); i++)
+    if (perms2[perms1[i]] != i)
+      return false;
+  return true;
+}
+// Primary overload for those with TosaElementwiseOperator trait.
+// The other ones handle the case of the operations that occur at the
+// roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+    Operation *op, const DenseMap<Value, Value> &valuesMap,
+    IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+  if (op->getNumResults() != 1 ||
+      !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+    return std::nullopt;
+  auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
+  SmallVector<Value, 3> operands;
+  for (Value v : op->getOperands()) {
+    if (valuesMap.contains(v)) {
+      operands.push_back(valuesMap.at(v));
+    } else {
+      return std::nullopt;
+    }
+  }
+  // Conceptually, we propagate the downstream TransposeOp through
+  // these intervening operations. For example,
+  //   %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
+  //   %1 = tosa.transpose %0 {perms = [1, 0]}
+  //        : (tensor<2x3xi32>) -> tensor<3x2xi32>
+  // becomes:
+  //   %0 = tosa.transpose %input {perms = [1, 0]}
+  //        : (tensor<2x3xi32>) -> tensor<3x2xi32>
+  //   %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>
+  // We construct this new tosa.clamp here, but it doesn't turn "live" until
+  // the final downstream transpose in the chain (whose dependencies we are
+  // currently traversing) is replaced with the proper value from this new
+  // chain.
+  return rewriter
+      .create(op->getLoc(),
+              rewriter.getStringAttr(op->getName().getStringRef()), operands,
+              RankedTensorType::get(tosa::applyTOSAPermutation(
+                                        resultType.getShape(), downstreamPerms),
+                                    resultType.getElementType()),
+              op->getAttrs())
+      ->getResult(0);
+}
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+    tosa::TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
+    IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+  SmallVector<int32_t> perms;
+  if (failed(transposeOp.getConstantPerms(perms)) ||
+      !areNullifyingTransposes(downstreamPerms, perms))
+    return std::nullopt;
+  return transposeOp.getInput1();
+}
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+    tosa::ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
+    IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+  auto reshapeOutput = reshapeOp.getOutput();
+  auto reshapeInputType =
+      llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
+  // want reshape N -> 1x1x...x1xNx1x...x1x1
+  // Check the dyn_cast result before querying the shape, since it is null
+  // when the input is not a RankedTensorType.
+  if (!reshapeInputType || reshapeInputType.getShape().size() != 1)
+    return std::nullopt;
+  auto reshapeInputShape = reshapeInputType.getShape();
+  auto reshapeOutputType =
+      llvm::cast<RankedTensorType>(reshapeOutput.getType());
+  // Instead of inserting a TransposeOp here, we
+  // check if we can fold it into the ReshapeOp.
+  // There are more complex cases where this is possible,
+  // and this check can be extended.
+  // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1
+  auto shape = reshapeOutputType.getShape();
+  size_t ones = llvm::count(shape, 1);
+  // Accept exactly one non-unit dimension (N != 1), or all unit dimensions
+  // when N == 1.
+  if (ones != shape.size() - 1 &&
+      !(ones == shape.size() && reshapeInputShape[0] == 1))
+    return std::nullopt;
+  // Do not insert a TransposeOp, instead we fold the reshape and its attribute.
+  auto foldedReshape = rewriter.create<tosa::ReshapeOp>(
+      reshapeOp.getLoc(),
+      RankedTensorType::get(tosa::applyTOSAPermutation(shape, downstreamPerms),
+                            reshapeOutputType.getElementType()),
+      reshapeOp.getInput1(),
+      rewriter.getDenseI64ArrayAttr(tosa::applyTOSAPermutation(
+          reshapeOp.getNewShape(), downstreamPerms)));
+  return foldedReshape->getResult(0);
+}
+std::optional<Value> TosaRemoveRedundantTransposes::buildMappedToValue(
+    tosa::ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
+    IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+  auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
+  if (!denseAttr)
+    return std::nullopt;
+  auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, downstreamPerms);
+  if (!maybeNewDenseAttr.has_value())
+    return std::nullopt;
+  auto newDenseAttr = maybeNewDenseAttr.value();
+  auto newConstOp = rewriter.create<tosa::ConstOp>(
+      constOp.getLoc(), newDenseAttr.getType(), newDenseAttr);
+  return newConstOp->getResult(0);
+}
+bool TosaRemoveRedundantTransposes::convertDependentOps(
+    SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
+    IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms) {
+  for (Operation *op : dependentOps) {
+    if (!op || op->getNumResults() != 1)
+      return false;
+    Value priorValue = op->getResult(0);
+    // It's possible on a prior transposeOp
+    // we had the same dependency and already resolved it.
+    if (valuesMap.contains(priorValue))
+      continue;
+    // Keep converted ops close to the original.
+    rewriter.setInsertionPointAfter(op);
+    std::optional<Value> maybeValue =
+        llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
+            .Case<tosa::TransposeOp>([&](tosa::TransposeOp transposeOp) {
+              return buildMappedToValue(transposeOp, valuesMap, rewriter,
+                                        downstreamPerms);
+            })
+            .Case<tosa::ReshapeOp>([&](tosa::ReshapeOp reshapeOp) {
+              return buildMappedToValue(reshapeOp, valuesMap, rewriter,
+                                        downstreamPerms);
+            })
+            .Case<tosa::ConstOp>([&](tosa::ConstOp constOp) {
+              return buildMappedToValue(constOp, valuesMap, rewriter,
+                                        downstreamPerms);
+            })
+            .Default([&](Operation *op) {
+              return buildMappedToValue(op, valuesMap, rewriter,
+                                        downstreamPerms);
+            });
+    if (!maybeValue.has_value())
+      return false;
+    valuesMap[priorValue] = maybeValue.value();
+  }
+  return true;
+}
+// Dependencies are valid for an operation if none of them occur outside
+// of the proper fan-in cones of the downstream TransposeOp with the same perms
+// that we can replace. Described in more detail within.
+bool TosaRemoveRedundantTransposes::dependenciesAreValid(
+    ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+    std::set<tosa::TransposeOp> &validTransposes,
+    std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+        &transposeInfo) {
+  for (Operation *op : dependentOps) {
+    // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
+    // This can be changed later if we find the memory impact is too high.
+    if (llvm::isa<tosa::ConstOp>(op))
+      continue;
+    for (OpOperand &use : op->getUses()) {
+      // Want the uses to be (1) contained in the dependentOps of other
+      // validTransposes, or (2) to be directly used in a TransposeOp with the
+      // same perms. For (2) it means the fan-in is a subset of our
+      // dependentOps, so it is also a validTranspose that will eventually be
+      // replaced.
+      Operation *user = use.getOwner();
+      if (auto otherTranspose = llvm::dyn_cast<tosa::TransposeOp>(user)) {
+        SmallVector<int32_t> otherPerms;
+        // Can later think about cases where transpose -> transpose
+        // or reshape -> transpose, where the transposes are not necessarily
+        // the same perms as the downstream, if implementing a more general
+        // transform. These could be permitted.
+        if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
+            !llvm::equal(perms, otherPerms))
+          return false;
+      } else if (llvm::none_of(
+                     transposeInfo,
+                     [&validTransposes,
+                      user](const std::pair<tosa::TransposeOp,
+                                            SetVector<Operation *>> &info) {
+                       const auto &[transposeOp, dependentOps] = info;
+                       return validTransposes.count(transposeOp) &&
+                              dependentOps.contains(user);
+                     })) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+// Getting the set of TransposeOp that we can replace without causing
+// the old fan-in cones of any TransposeOp to remain "live", i.e., not being
+// dead code. This is done by iterating the set until convergence, since
+// if you are used outside your own fan-in cone, it's possible to be used
+// in another fan-in cone of a TransposeOp that is being replaced -- unless
+// we find that that one has a usage outside of it too.
+std::set<tosa::TransposeOp> TosaRemoveRedundantTransposes::getGoodReplacements(
+    ArrayRef<int32_t> perms,
+    std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+        &transposeInfo) {
+  // Initially, we assume they are all good to replace,
+  // and we whittle them down based on our criteria.
+  std::set<tosa::TransposeOp> ableToReplace;
+  for (const auto &[transposeOp, _] : transposeInfo)
+    ableToReplace.insert(transposeOp);
+  bool gotRid;
+  do {
+    gotRid = false;
+    for (const auto &[transposeOp, dependentOps] : transposeInfo) {
+      // We don't care about it. Already invalidated.
+      if (!ableToReplace.count(transposeOp))
+        continue;
+      // Check for validity.
+      if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
+                                transposeInfo)) {
+        ableToReplace.erase(transposeOp);
+        gotRid = true;
+        break;
+      }
+    }
+  } while (gotRid);
+  return ableToReplace;
+}
+void TosaRemoveRedundantTransposes::runOnOperation() {
+  // We want to operate only within a single block.
+  // Call --inline before to run the pass.
sjarus wrote:

This is probably a stale comment - it's leftover from early conversation on TBD work on control flow ops that have nested region bodies and other such constructs. Would removing this comment resolve this, @joker-eph ?
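For readers following the quoted patch: areNullifyingTransposes accepts an upstream transpose only when composing it with the downstream one gives the identity. A minimal standalone sketch of that check, using plain std::vector instead of ArrayRef, might look like the following. The helper applyPermutation is hypothetical and assumes the TOSA convention output[i] = input[perms[i]]; it is not the patch's tosa::applyTOSAPermutation.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the patch). Assumes the TOSA transpose
// convention: output[i] = input[perms[i]].
std::vector<int64_t> applyPermutation(const std::vector<int64_t> &input,
                                      const std::vector<int32_t> &perms) {
  assert(input.size() == perms.size() && "rank mismatch");
  std::vector<int64_t> output;
  output.reserve(perms.size());
  for (int32_t p : perms)
    output.push_back(input[p]);
  return output;
}

// Same test as in the quoted patch: the two permutations cancel when
// perms2[perms1[i]] == i for every i. Both inputs are assumed to be valid
// permutations of 0 .. size - 1.
bool areNullifying(const std::vector<int32_t> &perms1,
                   const std::vector<int32_t> &perms2) {
  if (perms1.size() != perms2.size())
    return false;
  for (int32_t i = 0; i < static_cast<int32_t>(perms1.size()); i++)
    if (perms2[perms1[i]] != i)
      return false;
  return true;
}
```

For instance, {0, 3, 1, 2} (NHWC to NCHW) and {0, 2, 3, 1} (NCHW to NHWC) cancel each other, whereas {2, 0, 1} applied twice does not.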

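The stride-based constant transposition in transposeDenseAttribute can be sketched outside of MLIR as below. This is only an illustration on a flat std::vector<int>, assuming row-major layout and that output dimension i reads input dimension perms[i]; transposeFlat and its parameters are hypothetical names, not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical helper (not from the patch): transposes a row-major flat
// array. Assumes output dimension i iterates input dimension perms[i].
std::vector<int> transposeFlat(const std::vector<int> &input,
                               const std::vector<int64_t> &shape,
                               const std::vector<int32_t> &perms) {
  const std::size_t rank = shape.size();
  assert(rank > 0 && perms.size() == rank && "perms must match the rank");

  // Steps 1/2: row-major strides of the input.
  std::vector<int64_t> inputStrides(rank);
  inputStrides[rank - 1] = 1;
  for (int64_t i = static_cast<int64_t>(rank) - 2; i >= 0; i--)
    inputStrides[i] = inputStrides[i + 1] * shape[i + 1];

  // Step 3: reorder bounds and strides into output index order.
  std::vector<int64_t> outputShape(rank), permutedStrides(rank);
  for (std::size_t i = 0; i < rank; i++) {
    outputShape[i] = shape[perms[i]];
    permutedStrides[i] = inputStrides[perms[i]];
  }

  // Step 4: walk the output in row-major order, one dimension per
  // recursion level, accumulating the matching input offset.
  std::vector<int> result;
  result.reserve(input.size());
  std::function<void(int64_t, std::size_t)> emit = [&](int64_t offset,
                                                       std::size_t dim) {
    if (dim == rank) {
      result.push_back(input[offset]);
      return;
    }
    for (int64_t i = 0; i < outputShape[dim]; i++)
      emit(offset + i * permutedStrides[dim], dim + 1);
  };
  emit(0, 0);
  return result;
}
```

With a 2x3 input {1, 2, 3, 4, 5, 6} and perms {1, 0}, this yields {1, 4, 2, 5, 3, 6}. For the 2x3x4 example with perms {2, 0, 1}, output element 6i + 3j + k is read from input element 12j + 4k + i, matching the comment in the patch.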

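The convergence loop in getGoodReplacements can be modeled roughly as follows. Operations are reduced to integer ids, and Candidate, fanInIsContained, pruneCandidates, and the users map are hypothetical stand-ins for the pass's TransposeOp / SetVector bookkeeping. All candidates are assumed to share the same perms, and the ConstOp exemption is left out.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Hypothetical model (not from the patch): ops are plain integer ids.
struct Candidate {
  int transposeId;     // downstream transpose we would like to replace
  std::set<int> fanIn; // ops in its fan-in cone
};

// A user is acceptable if it is itself one of the same-perms transposes, or
// if it lies in the fan-in of a candidate that is still considered valid.
bool fanInIsContained(const Candidate &c,
                      const std::vector<Candidate> &candidates,
                      const std::set<int> &valid,
                      const std::map<int, std::set<int>> &users) {
  for (int op : c.fanIn) {
    auto it = users.find(op);
    if (it == users.end())
      continue;
    for (int user : it->second) {
      bool ok = false;
      for (const Candidate &other : candidates) {
        if (other.transposeId == user ||
            (valid.count(other.transposeId) && other.fanIn.count(user))) {
          ok = true;
          break;
        }
      }
      if (!ok)
        return false;
    }
  }
  return true;
}

// Start by assuming every candidate is replaceable, then drop candidates
// one at a time until no more can be invalidated.
std::set<int> pruneCandidates(const std::vector<Candidate> &candidates,
                              const std::map<int, std::set<int>> &users) {
  std::set<int> valid;
  for (const Candidate &c : candidates)
    valid.insert(c.transposeId);
  bool changed;
  do {
    changed = false;
    for (const Candidate &c : candidates) {
      if (!valid.count(c.transposeId))
        continue;
      if (!fanInIsContained(c, candidates, valid, users)) {
        valid.erase(c.transposeId);
        changed = true;
        break;
      }
    }
  } while (changed);
  return valid;
}
```

The iteration matters because invalidating one candidate can expose an outside use in another: a candidate whose fan-in is consumed only by a second candidate's cone stops being safe once that second candidate is dropped.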