[Mlir-commits] [mlir] ded988e - [mlir][tosa] Remove redundant "tosa.transpose" operations
Rob Suderman
llvmlistbot at llvm.org
Tue Jan 10 13:58:08 PST 2023
Author: Aviad Cohen
Date: 2023-01-10T13:56:25-08:00
New Revision: ded988ed0c00e033aa7fa9ea42d7ad19f3dd983e
URL: https://github.com/llvm/llvm-project/commit/ded988ed0c00e033aa7fa9ea42d7ad19f3dd983e
DIFF: https://github.com/llvm/llvm-project/commit/ded988ed0c00e033aa7fa9ea42d7ad19f3dd983e.diff
LOG: [mlir][tosa] Remove redundant "tosa.transpose" operations
We can fold redundant tosa::TransposeOp operations, such as an identity transpose or a transpose of a transpose (transpose(transpose(A))).
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D140466
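Illustrative sketch (editor's addition, mirroring the first test case added below; SSA names are hypothetical): the canonicalization composes the two permutation arrays as composed[i] = innerPerms[outerPerms[i]], and the folder then removes the resulting transpose when the composed permutation is the identity.

    // transpose(transpose(%arg0)) with inner perms [1, 2, 0] and outer perms [2, 0, 1].
    %p0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
    %t0 = "tosa.transpose"(%arg0, %p0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
    %p1 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
    %t1 = "tosa.transpose"(%t0, %p1) : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
    // Composed permutation: [inner[2], inner[0], inner[1]] = [0, 1, 2], the identity,
    // so after --canonicalize %t1 is replaced by %arg0.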
Added:
mlir/test/IR/transpose-fold.mlir
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b73368fc086d7..6609c6bdf8199 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1542,6 +1542,10 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
outs Tosa_Tensor1Dto6D:$output
);
+ let extraClassDeclaration = [{
+ LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
+ }];
+
let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f8b48f10a22f7..5f44634d5482d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -131,29 +131,49 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
return success();
}
-struct TransposeNoOp : public OpRewritePattern<tosa::TransposeOp> {
+struct ConsolidateTransposeOptimization
+ : public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(tosa::TransposeOp op,
+ LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
- auto perm = op.getPerms();
+ // Input is also TransposeOp - transpose(transpose(A)).
+ auto innerTranspose =
+ transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
+ if (!innerTranspose)
+ return rewriter.notifyMatchFailure(transposeOp,
+ "input must be transpose operation");
+
+ SmallVector<int64_t> transposePerms, innerTransposePerms;
+ if (transposeOp.getConstantPerms(transposePerms).failed())
+ return rewriter.notifyMatchFailure(transposeOp,
+ "transpose perms must be constant");
+ if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
+ return rewriter.notifyMatchFailure(
+ transposeOp, "inner transpose perms must be constant");
+ if (transposePerms.size() != innerTransposePerms.size())
+ return rewriter.notifyMatchFailure(
+ transposeOp,
+ "transpose and inner transpose perms sizes must be equal");
+ if (transposePerms.empty())
+ return rewriter.notifyMatchFailure(
+ transposeOp, "transpose perms sizes must be positive");
- DenseIntElementsAttr permAttr;
- if (!matchPattern(perm, m_Constant(&permAttr))) {
- return failure();
- }
+ // Consolidate transposes into one transpose.
+ SmallVector<int32_t> perms(transposePerms.size());
+ for (int i = 0, s = transposePerms.size(); i < s; ++i)
+ perms[i] = innerTransposePerms[transposePerms[i]];
- SmallVector<int64_t> permValues = llvm::to_vector<6>(
- llvm::map_range(permAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+ auto permsTy =
+ RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
+ auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
+ Value permsValue =
+ rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);
- for (int i = 0, s = permValues.size(); i < s; i++) {
- if (i != permValues[i]) {
- return failure();
- }
- }
+ rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
+ transposeOp, transposeOp.getResult().getType(),
+ innerTranspose.getInput1(), permsValue);
- rewriter.replaceOp(op, op.getInput1());
return success();
}
};
@@ -212,7 +232,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<TransposeNoOp, TransposeIsReshape>(context);
+ results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
@@ -997,26 +1017,27 @@ OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
}
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[1])
- return {};
-
auto inputTy = getInput1().getType().cast<ShapedType>();
auto resultTy = getType().cast<ShapedType>();
- if (inputTy.getElementType() != resultTy.getElementType())
- return {};
// Transposing splat values just means reshaping.
if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
- if (input.isSplat())
- return input.reshape(getType().cast<ShapedType>());
+ if (input.isSplat() && resultTy.hasStaticShape() &&
+ inputTy.getElementType() == resultTy.getElementType())
+ return input.reshape(resultTy);
}
- auto perms = llvm::to_vector<6>(llvm::map_range(
- operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+ // Transpose does not change the input type.
+ if (getInput1().getType() != getType())
+ return {};
- if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
- getInput1().getType() == getType())
- return getInput1();
- return {};
+  // Fold to the input only when the permutation is the identity.
+ SmallVector<int64_t> perms;
+ if (getConstantPerms(perms).failed())
+ return {};
+
+ if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
+ return {};
+
+ return getInput1();
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7ce00812064b5..82f34f93cb3d1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -688,6 +688,20 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}
+LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
+ // Perms must be constants.
+ DenseIntElementsAttr permsAttr;
+ if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
+ return failure();
+
+  // Convert the permutation elements to int64_t values.
+ perms = llvm::to_vector(
+ llvm::map_range(permsAttr.getValues<APInt>(),
+ [](const APInt &val) { return val.getSExtValue(); }));
+
+ return success();
+}
+
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
diff --git a/mlir/test/IR/transpose-fold.mlir b/mlir/test/IR/transpose-fold.mlir
new file mode 100644
index 0000000000000..1079bf3e02cfb
--- /dev/null
+++ b/mlir/test/IR/transpose-fold.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s --canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @test_cancel_transpose_transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
+// CHECK: return %[[VAL_0]] : tensor<1x2x3xi32>
+// CHECK: }
+
+func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
+ %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
+ %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
+ %2 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
+ %3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+ return %3 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_remove_identity_transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
+// CHECK: return %[[VAL_0]] : tensor<1x2x3xi32>
+// CHECK: }
+
+func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
+ %0 = arith.constant dense<[0, 1, 2]> : tensor<3xi32>
+ %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<1x2x3xi32>)
+ return %1 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_do_not_cancel_different_transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
+// CHECK: %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
+// CHECK: return %[[VAL_2]] : tensor<5x4x3x2xi32>
+// CHECK: }
+
+func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) -> (tensor<5x4x3x2xi32>) {
+ %0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
+ %1 = "tosa.transpose"(%arg0, %0) : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> (tensor<3x4x2x5xi32>)
+ %2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
+ %3 = "tosa.transpose"(%1, %2) : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
+ return %3 : tensor<5x4x3x2xi32>
+}
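To reproduce locally (editor's note; assumes a built mlir-opt and the flags taken from the RUN line above):

    mlir-opt --canonicalize -split-input-file mlir/test/IR/transpose-fold.mlir

The first two functions should fold to a plain return of their argument, while the last should keep a single consolidated tosa.transpose, as the CHECK lines expect.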