[Mlir-commits] [mlir] 198d1d9 - [mlir][tosa] Prefer tosa.transpose composition canonicalization to reshape
Rob Suderman
llvmlistbot at llvm.org
Wed Jan 18 09:13:54 PST 2023
Author: Rob Suderman
Date: 2023-01-18T08:59:24-08:00
New Revision: 198d1d99769700d0136ac90275a8d6fa1871accf
URL: https://github.com/llvm/llvm-project/commit/198d1d99769700d0136ac90275a8d6fa1871accf
DIFF: https://github.com/llvm/llvm-project/commit/198d1d99769700d0136ac90275a8d6fa1871accf.diff
LOG: [mlir][tosa] Prefer tosa.transpose composition canonicalization to reshape
It is preferred to merge tosa.transpose operations together rather than convert
one to a tosa.reshape. This is to leverage the tosa.transpose -> tosa.transpose
merging canonicalization.
Reviewed By: AviadCo
Differential Revision: https://reviews.llvm.org/D141434
Added:
mlir/test/Dialect/Tosa/transpose-fold.mlir
Modified:
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Removed:
mlir/test/IR/transpose-fold.mlir
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 625c85593a9cf..74325c85de2b3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -188,6 +188,17 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
return rewriter.notifyMatchFailure(op, "Non-constant permutation");
+ if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
+ return rewriter.notifyMatchFailure(
+ op, "Src is from transpose, can compose transposes");
+
+ Value result = op.getResult();
+ for (Operation *subop : result.getUsers()) {
+ if (dyn_cast_or_null<tosa::TransposeOp>(subop))
+ return rewriter.notifyMatchFailure(
+ op, "Dest is used by transpose, can compose transposes");
+ }
+
auto input = op.getInput1();
auto inputTy = input.getType().cast<ShapedType>();
if (!inputTy.hasRank())
diff --git a/mlir/test/IR/transpose-fold.mlir b/mlir/test/Dialect/Tosa/transpose-fold.mlir
similarity index 70%
rename from mlir/test/IR/transpose-fold.mlir
rename to mlir/test/Dialect/Tosa/transpose-fold.mlir
index 1079bf3e02cfb..df49b7940b34d 100644
--- a/mlir/test/IR/transpose-fold.mlir
+++ b/mlir/test/Dialect/Tosa/transpose-fold.mlir
@@ -42,3 +42,20 @@ func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) ->
%3 = "tosa.transpose"(%1, %2) : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
return %3 : tensor<5x4x3x2xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @test_prefer_compose_transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
+// CHECK: %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
+// CHECK: return %[[VAL_2]] : tensor<4x3x2x1xi32>
+// CHECK: }
+
+func.func @test_prefer_compose_transpose(%arg0: tensor<1x2x3x4xi32>) -> (tensor<4x3x2x1xi32>) {
+ %0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
+ %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> (tensor<2x3x1x4xi32>)
+ %2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
+ %3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
+ return %3 : tensor<4x3x2x1xi32>
+}
More information about the Mlir-commits mailing list.