[Mlir-commits] [mlir] [mlir][Linalg] Add folders for `linalg.transpose` (PR #81709)

llvmlistbot at llvm.org
Tue Feb 13 22:28:51 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

This PR adds folders for `linalg.transpose` ops that have only one dimension or an identity permutation. The fold removes the `linalg.transpose` and propagates the input tensor in its place. Because this is a destination-passing style (DPS) op, I'm now wondering whether this folding is incorrect and whether we should instead replace the op with a `linalg.copy` so that the init tensor is still used. I think propagating the input tensor when the DPS op is folded away should be OK, given that all uses of the op's result (which is tied to the init tensor) are replaced with the input tensor, but I might be missing something. Feedback would be appreciated.
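
As a rough illustration of the two conditions described above, the following standalone C++ sketch checks whether a permutation makes the transpose a no-op. It does not use MLIR: `isIdentityPerm` and `transposeIsNoOp` are hypothetical names introduced here for illustration, whereas the folder in the diff calls `isIdentityPermutation` on `getPermutation()`.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the identity check used by the folder:
// returns true when perm[i] == i for every position i.
bool isIdentityPerm(const std::vector<int64_t> &perm) {
  for (std::size_t i = 0; i < perm.size(); ++i)
    if (perm[i] != static_cast<int64_t>(i))
      return false;
  return true;
}

// Hypothetical helper: returns true when a transpose with this permutation
// moves no data, so its result could be replaced by its input.
bool transposeIsNoOp(const std::vector<int64_t> &perm) {
  return perm.empty() || isIdentityPerm(perm);
}
```

In this sketch a 1-D permutation `[0]` and a 3-D permutation `[0, 1, 2]` both count as no-ops, while `[1, 0]` does not. An empty permutation also passes the identity loop trivially, so the explicit `perm.empty()` check is kept only to mirror the two separate branches in the folder.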

---
Full diff: https://github.com/llvm/llvm-project/pull/81709.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+2-1) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+16) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+34) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..de9414598b0282 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -245,7 +245,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
           }
     ```
 
-    Shortened print form is available. Applies to simple maps with one 
+    Shortened print form is available. Applies to simple maps with one
     non-yield operation inside the body.
 
     The example above will be printed as:
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
                              ::mlir::OperationState & odsState);
   }];
 
+  let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e86b9762d8581f..2f6ab7e32e5872 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1785,6 +1785,22 @@ void TransposeOp::getEffects(
                         getDpsInits());
 }
 
+LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
+                                SmallVectorImpl<OpFoldResult> &result) {
+  // Single dimension transpose.
+  if (getPermutation().size() == 0) {
+    result.push_back(getInput());
+    return success();
+  }
+  // Identity permutation.
+  if (isIdentityPermutation(getPermutation())) {
+    result.push_back(getInput());
+    return success();
+  }
+
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 052dc367ca6779..5bc4cb82f8cdbc 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,3 +1017,37 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
   %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %copy : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @transpose_1d(%input: tensor<16xf32>,
+                        %init: tensor<16xf32>) -> tensor<16xf32> {
+  %transpose = linalg.transpose
+      ins(%input:tensor<16xf32>)
+      outs(%init:tensor<16xf32>)
+      permutation = [0]
+  func.return %transpose : tensor<16xf32>
+}
+
+// CHECK-LABEL: func @transpose_1d(
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<16xf32>,
+//  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<16xf32>)
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[INPUT]] : tensor<16xf32>
+
+// -----
+
+func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
+                                   %init: tensor<16x32x64xf32>) -> tensor<16x32x64xf32> {
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16x32x64xf32>)
+      permutation = [0, 1, 2]
+  func.return %transpose : tensor<16x32x64xf32>
+}
+
+// CHECK-LABEL: func @transpose_identity_perm(
+//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>,
+//  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>)
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[INPUT]] : tensor<16x32x64xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/81709

