[Mlir-commits] [mlir] [mlir][linalg] Add folder for transpose(transpose) -> transpose (PR #93606)

Ryan Holt llvmlistbot at llvm.org
Wed May 29 11:30:23 PDT 2024


https://github.com/ryan-holt-1 updated https://github.com/llvm/llvm-project/pull/93606

>From 0b2a7c6c28e44203fe627f9c0376fd782cfad603 Mon Sep 17 00:00:00 2001
From: ryan-holt-1 <ryanholt at mathworks.com>
Date: Tue, 28 May 2024 15:20:18 -0400
Subject: [PATCH] [mlir][linalg] Add folder for transpose(transpose) ->
 transpose

Back-to-back `linalg.transpose` ops can be rewritten to a single transpose with the composed permutation.
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  1 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 29 ++++++++++++
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 45 +++++++++++++++++++
 3 files changed, 75 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5ee363ed32572..ac61117c3d6e3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6a5f25a7605f1..0daceb9eeab60 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1866,6 +1866,35 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
   return failure();
 }
 
+/// Fold transpose(transpose(x, p1), p2) -> transpose(x, q), q[i] = p1[p2[i]].
+struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
+    if (!defTransposeOp)
+      return failure();
+    ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
+    ArrayRef<int64_t> perms = transposeOp.getPermutation();
+    // Result dim i is dim defPerms[perms[i]] of the producer's own input.
+    SmallVector<int64_t> foldedPerms;
+    foldedPerms.reserve(perms.size());
+    for (int64_t perm : perms)
+      foldedPerms.push_back(defPerms[perm]);
+
+    rewriter.replaceOpWithNewOp<TransposeOp>(
+        transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
+        foldedPerms);
+    return success();
+  }
+};
+
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldTransposeWithTranspose>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 19cea6c2066c9..6e8fad04077fe 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1051,3 +1051,48 @@ func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
 //   CHECK-NOT:   linalg.transpose
 //       CHECK:   return %[[INPUT]] : tensor<16x32x64xf32>
 
+// -----
+
+func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>,
+                                      %init1: tensor<4x3x5xf32>,
+                                      %init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> {
+  // CHECK-LABEL: @transpose_transpose_cancel
+  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
+  //   CHECK-NOT:   linalg.transpose
+  //       CHECK:   return %[[INPUT]] : tensor<5x4x3xf32>
+  %transpose1 = linalg.transpose
+      ins(%input:tensor<5x4x3xf32>)
+      outs(%init1:tensor<4x3x5xf32>)
+      permutation = [1, 2, 0]
+  %transpose2 = linalg.transpose
+      ins(%transpose1:tensor<4x3x5xf32>)
+      outs(%init2:tensor<5x4x3xf32>)
+      permutation = [2, 0, 1]
+  func.return %transpose2 : tensor<5x4x3xf32>
+}
+
+// -----
+
+func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
+                                    %init1: tensor<4x3x5xf32>,
+                                    %init2: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+  // CHECK-LABEL: @transpose_transpose_fold
+  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<3x4x5xf32>
+  //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<5x4x3xf32>) outs(%[[INIT2]] : tensor<3x4x5xf32>) permutation = [2, 1, 0]
+  //   CHECK-NOT:   linalg.transpose
+  //       CHECK:   return %[[TRANSPOSE]] : tensor<3x4x5xf32>
+  %transpose1 = linalg.transpose
+      ins(%input:tensor<5x4x3xf32>)
+      outs(%init1:tensor<4x3x5xf32>)
+      permutation = [1, 2, 0]
+  %transpose2 = linalg.transpose
+      ins(%transpose1:tensor<4x3x5xf32>)
+      outs(%init2:tensor<3x4x5xf32>)
+      permutation = [1, 0, 2]
+  func.return %transpose2 : tensor<3x4x5xf32>
+}
+



More information about the Mlir-commits mailing list