[Mlir-commits] [mlir] [MLIR] Make generic skip packing init operand when not used in DataLayoutPropagation (PR #146139)

Zhuoran Yin llvmlistbot at llvm.org
Mon Jun 30 21:04:28 PDT 2025


https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/146139

From bd5d5cf63061628445fe9abe4d270f18154d3404 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 17 Jun 2025 19:24:18 +0000
Subject: [PATCH 1/2] Make generic skip packing init when out not used

---
 .../Transforms/DataLayoutPropagation.cpp      | 39 ++++++++++----
 .../Linalg/data-layout-propagation.mlir       | 52 ++++++++++++++-----
 2 files changed, 69 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7188987e5e938..3b8ed6bfb6e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -358,6 +358,19 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   return newGenericOp;
 }
 
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  Block *block = genericOp.getBody();
+  int numBlockArgs = block->getNumArguments();
+  int numDpsOuts = genericOp.getNumDpsInits();
+  int initArgStartIndex = numBlockArgs - numDpsOuts;
+  for (int i = 0; i < numDpsOuts; ++i) {
+    int matchingInitArgIndex = initArgStartIndex + i;
+    if (!block->getArgument(matchingInitArgIndex).use_empty())
+      return false;
+  }
+  return true;
+}
+
 /// Bubbles up linalg.pack op through a producer generic op. This
 /// swap pack(generic) to generic(pack). The new generic op works on packed
 /// domain; pack ops are created for input and output operands. E.g.,
@@ -470,12 +483,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                      genericOp, opOperand);
 
-  // If the dps init operand of the generic is a tensor.empty forward the pack
-  // op destination.
+  // Forward the pack op destination if either of the following holds:
+  // 1) The dps init operand of the generic is a tensor.empty.
+  // 2) The dps init operand is write-only, i.e., its matching block
+  //    argument has no uses inside the generic op, so its original
+  //    contents never need to be packed.
   Value dest = packedOutOperand;
-  if (auto initTensor = genericOp.getDpsInitOperand(0)
-                            ->get()
-                            .getDefiningOp<tensor::EmptyOp>()) {
+  auto initTensor =
+      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+  if (initTensor || isGenericOutsNotUsed(genericOp)) {
     dest = packOpDest;
   }
   // pack(unpack) isn't naively foldable because the unpack op can be from
@@ -1101,12 +1117,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                      genericOp, genericOp.getDpsInitOperand(0));
   auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
 
-  // If the dps init operand of the generic is a tensor.empty, do not pack it
-  // and forward the new tensor.empty as a destination.
+  // Do not pack the dps init operand; forward the new tensor.empty as the
+  // destination if either of the following holds:
+  // 1) The dps init operand of the generic is a tensor.empty.
+  // 2) The dps init operand is write-only, i.e., its matching block
+  //    argument has no uses inside the generic op.
   Value dest = packedOutOperand;
-  if (auto initTensor = genericOp.getDpsInitOperand(0)
-                            ->get()
-                            .getDefiningOp<tensor::EmptyOp>()) {
+  auto initTensor =
+      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+  if (initTensor || isGenericOutsNotUsed(genericOp)) {
     if (destPack)
       dest = destPack.getDest();
   }
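
The use_empty check above is what makes the forwarding safe: the pack op's
tensor.empty destination may only replace the packed init when the generic
never reads its out block argument. As a contrasting sketch (illustrative IR,
not part of this patch's test suite), a generic that accumulates into its out
is not write-only, so its original init must still be packed:

  #map0 = affine_map<(d0, d1) -> (d0, d1)>
  %acc = linalg.generic {indexing_maps = [#map0, #map0],
                         iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>)
      outs(%init : tensor<128x256xi32>) {
    ^bb0(%in: i32, %out: i32):
      // %out is used, so the init is live and must not be replaced by an
      // empty destination.
      %sum = arith.addi %in, %out : i32
      linalg.yield %sum : i32
  } -> tensor<128x256xi32>
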
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 31c9e9ed3c501..6fc8d9f152f4e 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -435,6 +435,40 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims_unused_init(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32> {
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %empty = tensor.empty() : tensor<16x4x32x16xi32>
+  %pack = linalg.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [32, 16]
+    into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims_unused_init
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:         %[[PACKED_ARG0:.+]] = linalg.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK-SAME:      outs(%[[ARG1_EMPTY]]
+
+// -----
+
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
 func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
@@ -497,7 +531,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
-func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
+func.func @unpack_element_type_change_no_use(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
   %0 = tensor.empty() : tensor<12x56x56x64xf32>
   %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
   %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
@@ -509,17 +543,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 }
 
 // CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @unpack_element_type_change
+// CHECK-LABEL: func.func @unpack_element_type_change_no_use
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK:         %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[ARG0]]
-// CHECK-SAME:      outs(%[[ARG1_PACK]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1]]
@@ -1402,13 +1433,10 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
 // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
-// CHECK:         %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
-// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [4, 8]
-// CHECK-SAME:      into %[[ARG2_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:    ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME:    outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
+// CHECK-SAME:    outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG2]]
 // CHECK:         return %[[UNPACK]] : tensor<?x64xbf16>
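
The unpack push-down direction applies the same condition. A minimal
before/after sketch of the intent (names and elided bodies are illustrative,
loosely following the updated unpack_element_type_change_no_use test):

  // Before: the unpacked value feeds a generic whose init is never read.
  %0 = tensor.empty() : tensor<12x56x56x64xf32>
  %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
         inner_dims_pos = [3] inner_tiles = [32]
         into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
  %2 = linalg.generic ... ins(%1 : tensor<12x56x56x64xf32>)
         outs(%init : tensor<12x56x56x64xf16>) { ... }

  // After: the generic runs on the packed domain with a fresh tensor.empty
  // as its out; %init is never packed because its value is never read, and
  // a single unpack of the result restores the original layout.
  %empty = tensor.empty() : tensor<12x2x56x56x32xf16>
  %res = linalg.generic ... ins(%arg0 : tensor<12x2x56x56x32xf32>)
           outs(%empty : tensor<12x2x56x56x32xf16>) { ... }
  %out = linalg.unpack %res outer_dims_perm = [0, 3, 1, 2]
           inner_dims_pos = [3] inner_tiles = [32] into %init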

From d543525dcaf8e9856cca494c8bf40de8fcb825ac Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 30 Jun 2025 18:05:25 +0000
Subject: [PATCH 2/2] Simplify implementation according to review feedback

---
 .../Linalg/Transforms/DataLayoutPropagation.cpp     | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 3b8ed6bfb6e6f..31ac87bacf267 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -359,16 +359,9 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
 }
 
 static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
-  Block *block = genericOp.getBody();
-  int numBlockArgs = block->getNumArguments();
-  int numDpsOuts = genericOp.getNumDpsInits();
-  int initArgStartIndex = numBlockArgs - numDpsOuts;
-  for (int i = 0; i < numDpsOuts; ++i) {
-    int matchingInitArgIndex = initArgStartIndex + i;
-    if (!block->getArgument(matchingInitArgIndex).use_empty())
-      return false;
-  }
-  return true;
+  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
+    return genericOp.getMatchingBlockArgument(&operand).use_empty();
+  });
 }
 
 /// Bubbles up linalg.pack op through a producer generic op. This


