[Mlir-commits] [mlir] [mlir][Linalg] Support lowerUnPack for identity outer_dims_perm cases. (PR #79594)
Han-Chung Wang
llvmlistbot at llvm.org
Fri Jan 26 05:16:28 PST 2024
https://github.com/hanhanW created https://github.com/llvm/llvm-project/pull/79594
None
>From 48441d434408f24b2e7a6fc2ddb471f23ca471d4 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 26 Jan 2024 21:14:16 +0800
Subject: [PATCH] [mlir][Linalg] Support lowerUnPack for identity outer_dims_perm
cases.
---
.../Dialect/Linalg/Transforms/Transforms.cpp | 7 ++--
.../Dialect/Linalg/transform-lower-pack.mlir | 35 +++++++++++++++++++
2 files changed, 40 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4df105af5bcd6f..02bc3e672bf7a7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -377,8 +377,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
tensor::UnPackOp unPackOp) {
// 1. Filter out NYI cases.
- if (!unPackOp.getOuterDimsPerm().empty())
- return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
+ if (!unPackOp.getOuterDimsPerm().empty() &&
+ !isIdentityPermutation(unPackOp.getOuterDimsPerm())) {
+ return rewriter.notifyMatchFailure(unPackOp,
+ "non-identity outer dims perm NYI");
+ }
Location loc = unPackOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 316df431a9c0c8..926969bfc73880 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -163,6 +163,41 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @unpack_with_identity_outer_dims_perm(
+func.func @unpack_with_identity_outer_dims_perm(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<17x2x16x16x32x8xf32>, %[[ARG1:.*]]: tensor<129x47x16x16xf32>
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x2x32x16x16xf32>
+ // CHECK: %[[TRAN:.*]] = linalg.transpose
+ // CHECK-SAME: ins(%[[ARG0]] : tensor<17x2x16x16x32x8xf32>)
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<17x8x2x32x16x16xf32>)
+ // CHECK-SAME: permutation = [0, 5, 1, 4, 2, 3]
+ // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3], [4], [5]]
+ // CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
+ // CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
+ // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
+ // CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
+ %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
+ : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
+ return %unpack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+ transform.yield
+ }
+}
+
+// -----
+
// When an unpack is a plain 'unpad', lower it to a simple extract_slice.
// CHECK-LABEL: func.func @unpack_as_pad(
func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
More information about the Mlir-commits
mailing list