[Mlir-commits] [mlir] e325ebb - [mlir][tosa] Add some transpose folders
Lei Zhang
llvmlistbot at llvm.org
Fri Sep 24 12:25:19 PDT 2021
Author: Lei Zhang
Date: 2021-09-24T15:25:14-04:00
New Revision: e325ebb9c70bbdd48866926a42d4c4373b832035
URL: https://github.com/llvm/llvm-project/commit/e325ebb9c70bbdd48866926a42d4c4373b832035
DIFF: https://github.com/llvm/llvm-project/commit/e325ebb9c70bbdd48866926a42d4c4373b832035.diff
LOG: [mlir][tosa] Add some transpose folders
* If the input is a constant splat value, we just need to reshape it.
* If the input is a general constant with one user, we can also constant fold it without bloating the IR; see the example below.
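For example (a minimal sketch mirroring the @transpose_fold_2d_float test added below; the function name is illustrative), running --canonicalize on

  func @example() -> tensor<3x2xf32> {
    %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
    %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
    %0 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
    return %0 : tensor<3x2xf32>
  }

folds the transpose (and dead perms constant) away, leaving a single constant; values are written in short form here, the printed IR uses exponent notation:

  func @example() -> tensor<3x2xf32> {
    %0 = "tosa.const"() {value = dense<[[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32>
    return %0 : tensor<3x2xf32>
  }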
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D110439
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b5d556ae7cea3..cc8c8fab56e5b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1534,6 +1534,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
outs Tosa_Tensor1Dto6D:$output
);
+ let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 483493528b3a2..a51d02e865380 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -159,6 +159,71 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<ReshapeReshapeOptimization>(context);
}
+struct ConstantTransposeOptimization
+ : public OpRewritePattern<tosa::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ DenseElementsAttr inputValues;
+ if (!matchPattern(op.input1(), m_Constant(&inputValues)))
+ return failure();
+ // Make sure the input is a constant that has a single user.
+ if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
+ return failure();
+
+ DenseIntElementsAttr permAttr;
+ if (!matchPattern(op.perms(), m_Constant(&permAttr)))
+ return failure();
+ auto permValues = llvm::to_vector<6>(llvm::map_range(
+ // TOSA allows both 32- and 64-bit integer tensors here.
+ permAttr.getValues<APInt>(),
+ [](const APInt &val) { return val.getZExtValue(); }));
+
+ auto inputType = op.input1().getType().cast<ShapedType>();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ int64_t numElements = inputType.getNumElements();
+
+ auto outputType = op.getType().cast<ShapedType>();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+
+ SmallVector<Attribute, 4> outputValues;
+ outputValues.resize(numElements);
+
+ // Transpose the input constant. Because we don't know its rank in advance,
+ // we need to loop over the range [0, element count) and delinearize the
+ // index.
+ for (int srcLinearIndex = 0; srcLinearIndex < numElements;
+ ++srcLinearIndex) {
+ SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
+ int totalCount = srcLinearIndex;
+ for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
+ srcIndices[dim] = totalCount % inputShape[dim];
+ totalCount /= inputShape[dim];
+ }
+
+ SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
+ for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
+ dstIndices[dim] = srcIndices[permValues[dim]];
+
+ uint64_t dstLinearIndex = dstIndices.front();
+ for (int dim = 1; dim < outputType.getRank(); ++dim)
+ dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
+
+ outputValues[dstLinearIndex] = inputValues.getValue(srcIndices);
+ }
+
+ rewriter.replaceOpWithNewOp<tosa::ConstOp>(
+ op, outputType, DenseElementsAttr::get(outputType, outputValues));
+ return success();
+ }
+};
+
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<ConstantTransposeOptimization>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
@@ -225,15 +290,18 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
if (!operands[1])
return {};
- DenseIntElementsAttr perms = operands[1].cast<DenseIntElementsAttr>();
-
- bool isRange = true;
- for (auto it : llvm::enumerate(perms)) {
- isRange = isRange &&
- it.value().getSExtValue() == static_cast<int64_t>(it.index());
+ // Transposing splat values just means reshaping.
+ if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
+ if (input.isSplat())
+ return input.reshape(getType().cast<ShapedType>());
}
- if (isRange && input1().getType() == getType())
+ auto perms = llvm::to_vector<6>(llvm::map_range(
+ operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
+ [](const APInt &val) { return val.getSExtValue(); }));
+
+ if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
+ input1().getType() == getType())
return input1();
return {};
}
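As a worked example of the delinearize/linearize logic above (illustrative, not part of the commit): for an input of shape 2x3 with perms = [1, 0], source linear index 5 delinearizes to srcIndices = (1, 2); dstIndices[dim] = srcIndices[permValues[dim]] yields (2, 1); relinearizing against the 3x2 output shape gives 2 * 2 + 1 = 5, so the element at (1, 2) of the input lands at (2, 1) of the output.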
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e0a9d29403125..a3ede85817564 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --canonicalize %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
// CHECK-LABEL: @argmax_nofold
func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
@@ -237,3 +237,80 @@ func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_splat
+func @transpose_fold_splat() -> tensor<3x2xf32> {
+ %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // CHECK: %[[CST:.+]] = "tosa.const"()
+ // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_float
+func @transpose_fold_2d_float() -> tensor<3x2xf32> {
+ %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // CHECK: %[[CST:.+]] = "tosa.const"()
+ // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_int
+func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
+ %input = "tosa.const"() {value = dense<[[
+ [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
+ [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+ ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
+ %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
+ // CHECK: %[[CST:.+]] = "tosa.const"()
+ // CHECK-SAME{LITERAL}: value = dense<[
+ // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+ // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+ // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+ // CHECK-SAME{LITERAL}: ]>
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x1x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_input
+func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // CHECK: tosa.transpose
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_perms
+func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
+ %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+ // CHECK: tosa.transpose
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_multi_users
+func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
+ %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // CHECK: tosa.transpose
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
+}