[Mlir-commits] [mlir] e325ebb - [mlir][tosa] Add some transpose folders

Lei Zhang <llvmlistbot at llvm.org>
Fri Sep 24 12:25:19 PDT 2021


Author: Lei Zhang
Date: 2021-09-24T15:25:14-04:00
New Revision: e325ebb9c70bbdd48866926a42d4c4373b832035

URL: https://github.com/llvm/llvm-project/commit/e325ebb9c70bbdd48866926a42d4c4373b832035
DIFF: https://github.com/llvm/llvm-project/commit/e325ebb9c70bbdd48866926a42d4c4373b832035.diff

LOG: [mlir][tosa] Add some transpose folders

* If the input is a splat constant, transposing it just amounts to
  reshaping the constant to the output type, so we fold it directly.
* If the input is a general constant with a single user, we can also
  constant fold the transpose without bloating the IR, since the
  original constant becomes dead afterwards. See the sketch below.
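
For illustration, a rough before/after sketch of the two cases, based on
the test cases added below (value names are made up and the exact printed
form of the folded constants may differ):

  // Case 1: splat input, folded by TransposeOp::fold().
  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
  %splat = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
  %0 = "tosa.transpose"(%splat, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
  // folds to:
  %0 = "tosa.const"() {value = dense<4.000000e+00> : tensor<3x2xf32>} : () -> tensor<3x2xf32>

  // Case 2: non-splat constant with a single user, rewritten by the new
  // ConstantTransposeOptimization canonicalization pattern.
  %cst = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
  %1 = "tosa.transpose"(%cst, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
  // rewrites to:
  %1 = "tosa.const"() {value = dense<[[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>

In the single-user case, the new pattern computes the transposed values
element by element: each source linear index is delinearized into
multi-dimensional indices, which are permuted and then re-linearized into
the output shape.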

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D110439

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b5d556ae7cea3..cc8c8fab56e5b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1534,6 +1534,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
     outs Tosa_Tensor1Dto6D:$output
   );
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 483493528b3a2..a51d02e865380 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -159,6 +159,71 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<ReshapeReshapeOptimization>(context);
 }
 
+struct ConstantTransposeOptimization
+    : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    DenseElementsAttr inputValues;
+    if (!matchPattern(op.input1(), m_Constant(&inputValues)))
+      return failure();
+    // Make sure the input is a constant that has a single user.
+    if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
+      return failure();
+
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(op.perms(), m_Constant(&permAttr)))
+      return failure();
+    auto permValues = llvm::to_vector<6>(llvm::map_range(
+        // TOSA allows both 32- and 64-bit integer tensors here.
+        permAttr.getValues<APInt>(),
+        [](const APInt &val) { return val.getZExtValue(); }));
+
+    auto inputType = op.input1().getType().cast<ShapedType>();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    int64_t numElements = inputType.getNumElements();
+
+    auto outputType = op.getType().cast<ShapedType>();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+
+    SmallVector<Attribute, 4> outputValues;
+    outputValues.resize(numElements);
+
+    // Transpose the input constant. Because we don't know its rank in advance,
+    // we need to loop over the range [0, element count) and delinearize the
+    // index.
+    for (int srcLinearIndex = 0; srcLinearIndex < numElements;
+         ++srcLinearIndex) {
+      SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
+      int totalCount = srcLinearIndex;
+      for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
+        srcIndices[dim] = totalCount % inputShape[dim];
+        totalCount /= inputShape[dim];
+      }
+
+      SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
+      for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
+        dstIndices[dim] = srcIndices[permValues[dim]];
+
+      uint64_t dstLinearIndex = dstIndices.front();
+      for (int dim = 1; dim < outputType.getRank(); ++dim)
+        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
+
+      outputValues[dstLinearIndex] = inputValues.getValue(srcIndices);
+    }
+
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(
+        op, outputType, DenseElementsAttr::get(outputType, outputValues));
+    return success();
+  }
+};
+
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<ConstantTransposeOptimization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//
@@ -225,15 +290,18 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[1])
     return {};
 
-  DenseIntElementsAttr perms = operands[1].cast<DenseIntElementsAttr>();
-
-  bool isRange = true;
-  for (auto it : llvm::enumerate(perms)) {
-    isRange = isRange &&
-              it.value().getSExtValue() == static_cast<int64_t>(it.index());
+  // Transposing splat values just means reshaping.
+  if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
+    if (input.isSplat())
+      return input.reshape(getType().cast<ShapedType>());
   }
 
-  if (isRange && input1().getType() == getType())
+  auto perms = llvm::to_vector<6>(llvm::map_range(
+      operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
+      [](const APInt &val) { return val.getSExtValue(); }));
+
+  if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
+      input1().getType() == getType())
     return input1();
   return {};
 }

diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e0a9d29403125..a3ede85817564 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --canonicalize %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
 func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
@@ -237,3 +237,80 @@ func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
   %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_splat
+func @transpose_fold_splat() -> tensor<3x2xf32> {
+  %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  //               CHECK: %[[CST:.+]] = "tosa.const"()
+  // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_float
+func @transpose_fold_2d_float() -> tensor<3x2xf32> {
+  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  //               CHECK: %[[CST:.+]] = "tosa.const"()
+  // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_int
+func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
+  %input = "tosa.const"() {value = dense<[[
+    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
+    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+  ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
+  %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
+  //               CHECK: %[[CST:.+]] = "tosa.const"()
+  // CHECK-SAME{LITERAL}: value = dense<[
+  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+  // CHECK-SAME{LITERAL}: ]>
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x1x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_input
+func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: tosa.transpose
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_perms
+func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
+  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  // CHECK: tosa.transpose
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_multi_users
+func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
+  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: tosa.transpose
+  %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+  return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
+}


        

