[Mlir-commits] [mlir] [mlir][Linalg] Support lowerUnPack for identity outer_dims_perm cases. (PR #79594)

Han-Chung Wang llvmlistbot at llvm.org
Fri Jan 26 05:16:28 PST 2024


https://github.com/hanhanW created https://github.com/llvm/llvm-project/pull/79594

(No description was provided for this pull request.)

>From 48441d434408f24b2e7a6fc2ddb471f23ca471d4 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 26 Jan 2024 21:14:16 +0800
Subject: [PATCH] [mlir][Linalg] Support lowerUnPack for identity outer_dims_perm
 cases.

---
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  7 ++--
 .../Dialect/Linalg/transform-lower-pack.mlir  | 35 +++++++++++++++++++
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4df105af5bcd6f..02bc3e672bf7a7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -377,8 +377,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
                                                    tensor::UnPackOp unPackOp) {
   // 1. Filter out NYI cases.
-  if (!unPackOp.getOuterDimsPerm().empty())
-    return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
+  if (!unPackOp.getOuterDimsPerm().empty() &&
+      !isIdentityPermutation(unPackOp.getOuterDimsPerm())) {
+    return rewriter.notifyMatchFailure(unPackOp,
+                                       "non-identity outer dims perm NYI");
+  }
 
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 316df431a9c0c8..926969bfc73880 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -163,6 +163,41 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func.func @unpack_with_identity_outer_dims_perm(
+func.func @unpack_with_identity_outer_dims_perm(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // CHECK-SAME: %[[ARG0:.*]]: tensor<17x2x16x16x32x8xf32>, %[[ARG1:.*]]: tensor<129x47x16x16xf32>
+  //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x2x32x16x16xf32>
+  //      CHECK: %[[TRAN:.*]] = linalg.transpose
+  // CHECK-SAME:    ins(%[[ARG0]] : tensor<17x2x16x16x32x8xf32>)
+  // CHECK-SAME:   outs(%[[EMPTY]] : tensor<17x8x2x32x16x16xf32>)
+  // CHECK-SAME:   permutation = [0, 5, 1, 4, 2, 3]
+  //      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3], [4], [5]]
+  // CHECK-SAME:   : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
+  //      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
+  // CHECK-SAME:   : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
+  //      CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
+  // CHECK-SAME:        outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
+  %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
+    : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
+  return %unpack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+          transform.yield
+  }
+}
+
+// -----
+
 // When an unpack is a plain 'unpad', lower it to a simple extract_slice.
 // CHECK-LABEL: func.func @unpack_as_pad(
 func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {



More information about the Mlir-commits mailing list