[Mlir-commits] [mlir] ded988e - [mlir][tosa] Remove redundant "tosa.transpose" operations

Rob Suderman llvmlistbot at llvm.org
Tue Jan 10 13:58:08 PST 2023


Author: Aviad Cohen
Date: 2023-01-10T13:56:25-08:00
New Revision: ded988ed0c00e033aa7fa9ea42d7ad19f3dd983e

URL: https://github.com/llvm/llvm-project/commit/ded988ed0c00e033aa7fa9ea42d7ad19f3dd983e
DIFF: https://github.com/llvm/llvm-project/commit/ded988ed0c00e033aa7fa9ea42d7ad19f3dd983e.diff

LOG: [mlir][tosa] Remove redundant "tosa.transpose" operations

We can fold redundant tosa::TransposeOp operations, such as an identity transpose or a transpose(transpose(x)) chain.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D140466

Added: 
    mlir/test/IR/transpose-fold.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b73368fc086d7..6609c6bdf8199 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1542,6 +1542,10 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
     outs Tosa_Tensor1Dto6D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
+  }];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f8b48f10a22f7..5f44634d5482d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -131,29 +131,49 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
   return success();
 }
 
-struct TransposeNoOp : public OpRewritePattern<tosa::TransposeOp> {
+struct ConsolidateTransposeOptimization
+    : public OpRewritePattern<tosa::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
-    auto perm = op.getPerms();
+    // Input is also TransposeOp - transpose(transpose(A)).
+    auto innerTranspose =
+        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
+    if (!innerTranspose)
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "input must be transpose operation");
+
+    SmallVector<int64_t> transposePerms, innerTransposePerms;
+    if (transposeOp.getConstantPerms(transposePerms).failed())
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "transpose perms must be constant");
+    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
+      return rewriter.notifyMatchFailure(
+          transposeOp, "inner transpose perms must be constant");
+    if (transposePerms.size() != innerTransposePerms.size())
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "transpose and inner transpose perms sizes must be equal");
+    if (transposePerms.empty())
+      return rewriter.notifyMatchFailure(
+          transposeOp, "transpose perms sizes must be positive");
 
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(perm, m_Constant(&permAttr))) {
-      return failure();
-    }
+    // Consolidate transposes into one transpose.
+    SmallVector<int32_t> perms(transposePerms.size());
+    for (int i = 0, s = transposePerms.size(); i < s; ++i)
+      perms[i] = innerTransposePerms[transposePerms[i]];
 
-    SmallVector<int64_t> permValues = llvm::to_vector<6>(
-        llvm::map_range(permAttr.getValues<APInt>(),
-                        [](const APInt &val) { return val.getSExtValue(); }));
+    auto permsTy =
+        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
+    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
+    Value permsValue =
+        rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);
 
-    for (int i = 0, s = permValues.size(); i < s; i++) {
-      if (i != permValues[i]) {
-        return failure();
-      }
-    }
+    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
+        transposeOp, transposeOp.getResult().getType(),
+        innerTranspose.getInput1(), permsValue);
 
-    rewriter.replaceOp(op, op.getInput1());
     return success();
   }
 };
@@ -212,7 +232,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
 
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<TransposeNoOp, TransposeIsReshape>(context);
+  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
 struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
@@ -997,26 +1017,27 @@ OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
 }
 
 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[1])
-    return {};
-
   auto inputTy = getInput1().getType().cast<ShapedType>();
   auto resultTy = getType().cast<ShapedType>();
-  if (inputTy.getElementType() != resultTy.getElementType())
-    return {};
 
   // Transposing splat values just means reshaping.
   if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
-    if (input.isSplat())
-      return input.reshape(getType().cast<ShapedType>());
+    if (input.isSplat() && resultTy.hasStaticShape() &&
+        inputTy.getElementType() == resultTy.getElementType())
+      return input.reshape(resultTy);
   }
 
-  auto perms = llvm::to_vector<6>(llvm::map_range(
-      operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
-      [](const APInt &val) { return val.getSExtValue(); }));
+  // Transpose does not change the input type.
+  if (getInput1().getType() != getType())
+    return {};
 
-  if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
-      getInput1().getType() == getType())
-    return getInput1();
-  return {};
+  // Transpose is not the identity transpose.
+  SmallVector<int64_t> perms;
+  if (getConstantPerms(perms).failed())
+    return {};
+
+  if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
+    return {};
+
+  return getInput1();
 }

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7ce00812064b5..82f34f93cb3d1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -688,6 +688,20 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
   return mlir::success();
 }
 
+LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
+  // Perms must be constants.
+  DenseIntElementsAttr permsAttr;
+  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
+    return failure();
+
+  // Transpose is not the identity transpose.
+  perms = llvm::to_vector(
+      llvm::map_range(permsAttr.getValues<APInt>(),
+                      [](const APInt &val) { return val.getSExtValue(); }));
+
+  return success();
+}
+
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,

diff  --git a/mlir/test/IR/transpose-fold.mlir b/mlir/test/IR/transpose-fold.mlir
new file mode 100644
index 0000000000000..1079bf3e02cfb
--- /dev/null
+++ b/mlir/test/IR/transpose-fold.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s --canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL:   func.func @test_cancel_transpose_transpose(
+// CHECK-SAME:                                               %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
+// CHECK:           return %[[VAL_0]] : tensor<1x2x3xi32>
+// CHECK:         }
+
+func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
+	%0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
+	%1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
+	%2 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
+	%3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+  return %3 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @test_remove_identity_transpose(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
+// CHECK:           return %[[VAL_0]] : tensor<1x2x3xi32>
+// CHECK:         }
+
+func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
+	%0 = arith.constant dense<[0, 1, 2]> : tensor<3xi32>
+	%1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<1x2x3xi32>)
+  return %1 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @test_do_not_cancel_different_transpose(
+// CHECK-SAME:                                                      %[[VAL_0:.*]]: tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
+// CHECK:           %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
+// CHECK:           return %[[VAL_2]] : tensor<5x4x3x2xi32>
+// CHECK:         }
+
+func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) -> (tensor<5x4x3x2xi32>) {
+	%0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
+	%1 = "tosa.transpose"(%arg0, %0) : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> (tensor<3x4x2x5xi32>)
+	%2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
+	%3 = "tosa.transpose"(%1, %2) : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
+  return %3 : tensor<5x4x3x2xi32>
+}


        


More information about the Mlir-commits mailing list