[Mlir-commits] [mlir] [mlir] Reuse pack dest in tensor.pack decomposition (PR #108025)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 10 06:32:52 PDT 2024


https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/108025

The `lowerPack` transform has a special case that lowers a `tensor.pack` into a simple `tensor.pad` + `tensor.insert_slice`, but in that case the destination of the `tensor.insert_slice` is a newly created `tensor.empty` rather than the destination of the pack. This PR fixes the transform to reuse the original destination of the `tensor.pack`.

From 241ad8e6a49ca9e304738aa0bebfca81956fb92d Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 6 Sep 2024 12:10:11 -0400
Subject: [PATCH] [mlir] Reuse pack dest in tensor.pack decomposition

---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp  |  7 ++-----
 mlir/test/Dialect/Linalg/transform-lower-pack.mlir | 14 ++++++++------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0e5e563ed5450a..77f0ea9d2236ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -305,8 +305,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
     if (rankReduces == SliceVerificationResult::Success) {
       // This pack is just a plain pad.
       // Just insert the pad in the higher ranked tensor.
-      auto emptyOp =
-          rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
       // Offsets.
       SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                       rewriter.getIndexAttr(0));
@@ -317,9 +315,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
           tensor::getMixedSizes(rewriter, loc, packOp.getDest());
 
       auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-          loc, /*source=*/padOp, /*dest=*/emptyOp,
-          /*offsets=*/zeros, sizes,
-          /*strides=*/ones);
+          loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
+          /*offsets=*/zeros, sizes, /*strides=*/ones);
 
       LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
 
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index f34ef4f961483d..48bf1c151de8f5 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -62,14 +62,15 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func.func @pack_as_pad(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
 func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  //      CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
+  //      CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
   //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
-  //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
-  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
   // offsets.
   // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
   // sizes.
@@ -387,14 +388,15 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
 func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  //      CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
+  //      CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
   //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
-  //      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
-  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+  //      CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
   // offsets.
   // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
   // sizes.



More information about the Mlir-commits mailing list