[Mlir-commits] [mlir] [mlir][Linalg] Add folders for `linalg.transpose` (PR #81709)
Diego Caballero
llvmlistbot at llvm.org
Tue Feb 13 22:28:24 PST 2024
https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/81709
This PR adds folders for `linalg.transpose` ops that have only one dimension or an identity permutation. The fold removes the `linalg.transpose` and propagates the input tensor in its place. Given that this is a destination-passing style (DPS) op, I'm now wondering whether this folding is incorrect and whether we should instead replace the op with a `linalg.copy` so that the init tensor is still used. I think that propagating the input tensor when the DPS op is folded away should be OK, given that all the uses of the init tensor are replaced with the input tensor, but I might be missing something. Feedback would be appreciated.
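To make the folding condition concrete, here is a minimal standalone sketch of the check in plain C++, without any MLIR dependency. The names `isIdentityPerm` and `transposeFoldsToInput` are hypothetical stand-ins for illustration only; they are not the functions used in the patch, which relies on MLIR's own `isIdentityPermutation` helper.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an identity-permutation check: returns true when
// perm[i] == i for every position i. An empty permutation is trivially the
// identity.
bool isIdentityPerm(const std::vector<int64_t> &perm) {
  for (std::size_t i = 0; i < perm.size(); ++i)
    if (perm[i] != static_cast<int64_t>(i))
      return false;
  return true;
}

// Hypothetical predicate mirroring the folder's decision: the transpose can be
// replaced by its input when it does not reorder any dimension.
bool transposeFoldsToInput(const std::vector<int64_t> &perm) {
  return isIdentityPerm(perm);
}
```

Under this sketch, a 1-D transpose with `permutation = [0]` is already covered by the identity check, so it needs no separate case.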
From 49af44262a2f86074cb10eecc1af69baa2ab7be3 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 14 Feb 2024 06:16:22 +0000
Subject: [PATCH] [mlir][Linalg] Add folders for `linalg.transpose`
This PR adds folders for linalg transpose ops with only one dimension or
an identity permutation. The folding removes the `linalg.transpose` and
just propagates the input tensor. Given that this is a DPS op, I'm now
wondering if this folding is incorrect and we should instead replace the
op with a `linalg.copy` so that the init tensor is still used. Feedback
would be appreciated. I think that propagating the input tensor if the
DPS op is folded away should be ok, given that all the uses of the init
tensor are replaced with the input tensor, but I might be missing something.
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 3 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 16 +++++++++
mlir/test/Dialect/Linalg/canonicalize.mlir | 34 +++++++++++++++++++
3 files changed, 52 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..de9414598b0282 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -245,7 +245,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}
```
- Shortened print form is available. Applies to simple maps with one
+ Shortened print form is available. Applies to simple maps with one
non-yield operation inside the body.
The example above will be printed as:
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
::mlir::OperationState & odsState);
}];
+ let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e86b9762d8581f..2f6ab7e32e5872 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1785,6 +1785,22 @@ void TransposeOp::getEffects(
getDpsInits());
}
+LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
+ SmallVectorImpl<OpFoldResult> &result) {
+ // Empty permutation (rank-0); a 1-D transpose hits the identity check below.
+ if (getPermutation().size() == 0) {
+ result.push_back(getInput());
+ return success();
+ }
+ // Identity permutation.
+ if (isIdentityPermutation(getPermutation())) {
+ result.push_back(getInput());
+ return success();
+ }
+
+ return failure();
+}
+
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 052dc367ca6779..5bc4cb82f8cdbc 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,3 +1017,37 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
%copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
return %copy : tensor<?x?xf32>
}
+
+// -----
+
+func.func @transpose_1d(%input: tensor<16xf32>,
+ %init: tensor<16xf32>) -> tensor<16xf32> {
+ %transpose = linalg.transpose
+ ins(%input:tensor<16xf32>)
+ outs(%init:tensor<16xf32>)
+ permutation = [0]
+ func.return %transpose : tensor<16xf32>
+}
+
+// CHECK-LABEL: func @transpose_1d(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<16xf32>,
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<16xf32>)
+// CHECK-NOT: linalg.transpose
+// CHECK: return %[[INPUT]] : tensor<16xf32>
+
+// -----
+
+func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
+ %init: tensor<16x32x64xf32>) -> tensor<16x32x64xf32> {
+ %transpose = linalg.transpose
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x32x64xf32>)
+ permutation = [0, 1, 2]
+ func.return %transpose : tensor<16x32x64xf32>
+}
+
+// CHECK-LABEL: func @transpose_identity_perm(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>,
+// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>)
+// CHECK-NOT: linalg.transpose
+// CHECK: return %[[INPUT]] : tensor<16x32x64xf32>