[Mlir-commits] [mlir] 198d1d9 - [mlir][tosa] Prefer tosa.transpose composition canonicalization to reshape

Rob Suderman llvmlistbot at llvm.org
Wed Jan 18 09:13:54 PST 2023


Author: Rob Suderman
Date: 2023-01-18T08:59:24-08:00
New Revision: 198d1d99769700d0136ac90275a8d6fa1871accf

URL: https://github.com/llvm/llvm-project/commit/198d1d99769700d0136ac90275a8d6fa1871accf
DIFF: https://github.com/llvm/llvm-project/commit/198d1d99769700d0136ac90275a8d6fa1871accf.diff

LOG: [mlir][tosa] Prefer tosa.transpose composition canonicalization to reshape

It is preferred to merge tosa.transpose operations together rather than convert
one to a tosa.reshape. This is to leverage the tosa.transpose -> tosa.transpose
merging canonicalization.

Reviewed By: AviadCo

Differential Revision: https://reviews.llvm.org/D141434

Added: 
    mlir/test/Dialect/Tosa/transpose-fold.mlir

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Removed: 
    mlir/test/IR/transpose-fold.mlir


################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 625c85593a9cf..74325c85de2b3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -188,6 +188,17 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
     if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
       return rewriter.notifyMatchFailure(op, "Non-constant permutation");
 
+    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
+      return rewriter.notifyMatchFailure(
+          op, "Src is from transpose, can compose transposes");
+
+    Value result = op.getResult();
+    for (Operation *subop : result.getUsers()) {
+      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
+        return rewriter.notifyMatchFailure(
+            op, "Dest is used by transpose, can compose transposes");
+    }
+
     auto input = op.getInput1();
     auto inputTy = input.getType().cast<ShapedType>();
     if (!inputTy.hasRank())

diff --git a/mlir/test/IR/transpose-fold.mlir b/mlir/test/Dialect/Tosa/transpose-fold.mlir
similarity index 70%
rename from mlir/test/IR/transpose-fold.mlir
rename to mlir/test/Dialect/Tosa/transpose-fold.mlir
index 1079bf3e02cfb..df49b7940b34d 100644
--- a/mlir/test/IR/transpose-fold.mlir
+++ b/mlir/test/Dialect/Tosa/transpose-fold.mlir
@@ -42,3 +42,20 @@ func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) ->
 	%3 = "tosa.transpose"(%1, %2) : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
   return %3 : tensor<5x4x3x2xi32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @test_prefer_compose_transpose(
+// CHECK-SAME:                                                      %[[VAL_0:.*]]: tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
+// CHECK:           %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
+// CHECK:           return %[[VAL_2]] : tensor<4x3x2x1xi32>
+// CHECK:         }
+
+func.func @test_prefer_compose_transpose(%arg0: tensor<1x2x3x4xi32>) -> (tensor<4x3x2x1xi32>) {
+	%0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
+	%1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> (tensor<2x3x1x4xi32>)
+	%2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
+	%3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
+  return %3 : tensor<4x3x2x1xi32>
+}


        


More information about the Mlir-commits mailing list