[Mlir-commits] [mlir] 110c1b6 - [mlir][tosa] Improve performance of tosa.transpose constant folding

Robert Suderman llvmlistbot at llvm.org
Fri Mar 24 12:51:16 PDT 2023


Author: Spenser Bauman
Date: 2023-03-24T19:50:13Z
New Revision: 110c1b64a7b9984a604aa2809e0fb8c01278609d

URL: https://github.com/llvm/llvm-project/commit/110c1b64a7b9984a604aa2809e0fb8c01278609d
DIFF: https://github.com/llvm/llvm-project/commit/110c1b64a7b9984a604aa2809e0fb8c01278609d.diff

LOG: [mlir][tosa] Improve performance of tosa.transpose constant folding

Folding of the tosa.transpose operation is both time- and memory-
intensive because the underlying ElementsAttr is processed as a sequence
of Attributes. This change instead operates on the underlying raw data
of the ElementsAttr wherever possible.

In an example resnet50 network, this change reduces the time spent in
folding transpose ops from 35s to 1.5s.

Reviewed By: GeorgeARM, rsuderman, stellaraccident

Differential Revision: https://reviews.llvm.org/D146526

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
index 3147073819235..9e2102ee1d0ab 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
 
@@ -20,6 +21,106 @@ using namespace mlir::tosa;
 
 namespace {
 
+// Transposes `attr` (of type `inputType`) into a DenseElementsAttr of type
+// `outputType`, reading and writing the elements as the C++ type `BaseType`.
+// Output dimension i corresponds to input dimension permValues[i]. BaseType
+// must be iterable for the element type (same width and signedness), or
+// `getValues<BaseType>()` asserts.
+template <typename BaseType>
+DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
+                                ShapedType outputType,
+                                llvm::ArrayRef<int64_t> permValues) {
+  if (inputType.getNumElements() == 0)
+    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});
+
+  auto attrValues = attr.getValues<BaseType>();
+  auto inputShape = inputType.getShape();
+
+  // The inverted permutation map and strides of the output are used to compute
+  // the contribution of a given dimension to the destination linear index in
+  // an order-independent way.
+  auto outputStrides = computeStrides(outputType.getShape());
+  auto invertedPermValues = invertPermutationVector(permValues);
+
+  // Seed the buffer with a real element (the input is non-empty here) rather
+  // than relying on BaseType being default-constructible; APFloat is not.
+  auto initialValue = *std::begin(attrValues);
+  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);
+
+  for (const auto &it : llvm::enumerate(attrValues)) {
+    auto srcLinearIndex = it.index();
+
+    uint64_t dstLinearIndex = 0;
+    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
+      // Compute the index into the current dimension of the source vector.
+      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
+      srcLinearIndex /= inputShape[dim];
+
+      // Add the contribution of the current dimension to the output using the
+      // permutation map.
+      dstLinearIndex +=
+          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
+    }
+
+    outputValues[dstLinearIndex] = it.value();
+  }
+
+  return DenseElementsAttr::get(outputType,
+                                llvm::ArrayRef<BaseType>(outputValues));
+}
+
+// A type specialized transposition of an ElementsAttr.
+// This implementation tries to operate on the underlying data in its raw
+// representation when possible to avoid allocating a large number of Attribute
+// objects. Element types without a raw fast path fall back to a generic
+// element-by-element transposition over Attributes.
+DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
+                            ShapedType outputType,
+                            llvm::ArrayRef<int64_t> permValues) {
+  auto baseType = inputType.getElementType();
+
+  // Handle possible integer types. The C++ type must match the signedness of
+  // the element type: DenseElementsAttr will not iterate an unsigned type such
+  // as ui8 as int8_t, so unsigned element types use the uintN_t variants.
+  if (auto intType = baseType.dyn_cast<IntegerType>()) {
+    bool isUnsigned = intType.isUnsigned();
+    switch (intType.getWidth()) {
+    case 1:
+      return transposeType<bool>(attr, inputType, outputType, permValues);
+    case 8:
+      if (isUnsigned)
+        return transposeType<uint8_t>(attr, inputType, outputType, permValues);
+      return transposeType<int8_t>(attr, inputType, outputType, permValues);
+    case 16:
+      if (isUnsigned)
+        return transposeType<uint16_t>(attr, inputType, outputType, permValues);
+      return transposeType<int16_t>(attr, inputType, outputType, permValues);
+    case 32:
+      if (isUnsigned)
+        return transposeType<uint32_t>(attr, inputType, outputType, permValues);
+      return transposeType<int32_t>(attr, inputType, outputType, permValues);
+    case 64:
+      if (isUnsigned)
+        return transposeType<uint64_t>(attr, inputType, outputType, permValues);
+      return transposeType<int64_t>(attr, inputType, outputType, permValues);
+    default:
+      return transposeType<APInt>(attr, inputType, outputType, permValues);
+    }
+  }
+
+  // Handle possible float types.
+  if (baseType.isF32())
+    return transposeType<float>(attr, inputType, outputType, permValues);
+  if (baseType.isF64())
+    return transposeType<double>(attr, inputType, outputType, permValues);
+  if (baseType.isa<FloatType>())
+    return transposeType<APFloat>(attr, inputType, outputType, permValues);
+
+  // Any other element type (e.g. index or complex) cannot be iterated as
+  // APFloat; transpose it generically as a sequence of Attributes instead.
+  return transposeType<Attribute>(attr, inputType, outputType, permValues);
+}
+
 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -43,41 +144,12 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     auto permValues = llvm::to_vector<6>(llvm::map_range(
         // TOSA allows both 32- and 64-bit integer tensors here.
         permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getZExtValue(); }));
+        [](const APInt &val) { return val.getSExtValue(); }));
 
     auto inputType = op.getInput1().getType().cast<ShapedType>();
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    int64_t numElements = inputType.getNumElements();
-
-    SmallVector<Attribute, 4> outputValues;
-    outputValues.resize(numElements);
-
-    // Transpose the input constant. Because we don't know its rank in advance,
-    // we need to loop over the range [0, element count) and delinearize the
-    // index.
-    auto attrValues = inputValues.getValues<Attribute>();
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    for (const auto &it : llvm::enumerate(attrValues)) {
-      SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
-      int totalCount = it.index();
-      for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
-        srcIndices[dim] = totalCount % inputShape[dim];
-        totalCount /= inputShape[dim];
-      }
-
-      SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
-      for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
-        dstIndices[dim] = srcIndices[permValues[dim]];
-
-      uint64_t dstLinearIndex = dstIndices.front();
-      for (int dim = 1; dim < outputType.getRank(); ++dim)
-        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
-
-      outputValues[dstLinearIndex] = it.value();
-    }
 
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(
-        op, outputType, DenseElementsAttr::get(outputType, outputValues));
+    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 54ba37ace3030..6c8f5935b4cb0 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -46,6 +46,17 @@ func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
   return %1 : tensor<3x2xf32>
 }
 
+// CHECK-LABEL: @transpose_fold_2d_bool
+func.func @transpose_fold_2d_bool() -> tensor<3x2xi1> {
+  %input = "tosa.const"() {value = dense<[[true, false, false], [false, false, true]]> : tensor<2x3xi1>} : () -> tensor<2x3xi1>  // i1 input exercises the width-1 (bool) path of the folder
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>  // [1, 0] swaps the two axes: 2x3 -> 3x2
+  //               CHECK: %[[CST:.+]] = "tosa.const"()
+  // CHECK-SAME{LITERAL}: value = dense<[[true, false], [false, false], [false, true]]> : tensor<3x2xi1>
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xi1>, tensor<2xi32>) -> tensor<3x2xi1>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xi1>
+}
+
 // CHECK-LABEL: @transpose_fold_4d_int
 func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
   %input = "tosa.const"() {value = dense<[[


        


More information about the Mlir-commits mailing list