[Mlir-commits] [mlir] [MLIR] Make generic skip packing init operand when not used in DataLayoutPropagation (PR #146139)
Zhuoran Yin
llvmlistbot at llvm.org
Mon Jun 30 21:04:28 PDT 2025
https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/146139
From bd5d5cf63061628445fe9abe4d270f18154d3404 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 17 Jun 2025 19:24:18 +0000
Subject: [PATCH 1/2] Make generic skip packing init when out not used
---
.../Transforms/DataLayoutPropagation.cpp | 39 ++++++++++----
.../Linalg/data-layout-propagation.mlir | 52 ++++++++++++++-----
2 files changed, 69 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7188987e5e938..3b8ed6bfb6e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -358,6 +358,19 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
return newGenericOp;
}
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+ Block *block = genericOp.getBody();
+ int numBlockArgs = block->getNumArguments();
+ int numDpsOuts = genericOp.getNumDpsInits();
+ int initArgStartIndex = numBlockArgs - numDpsOuts;
+ for (int i = 0; i < numDpsOuts; ++i) {
+ int matchingInitArgIndex = initArgStartIndex + i;
+ if (!block->getArgument(matchingInitArgIndex).use_empty())
+ return false;
+ }
+ return true;
+}
+
/// Bubbles up linalg.pack op through a producer generic op. This
/// swaps pack(generic) to generic(pack). The new generic op works on packed
/// domain; pack ops are created for input and output operands. E.g.,
@@ -470,12 +483,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, opOperand);
- // If the dps init operand of the generic is a tensor.empty forward the pack
- // op destination.
+ // Forward the new tensor.empty as a destination if one of the following
+ // holds:
+ // 1) The dps init operand is a tensor.empty.
+ // 2) The dps init is a write-only operand, i.e., it is not used in the
+ // genericOp body.
Value dest = packedOutOperand;
- if (auto initTensor = genericOp.getDpsInitOperand(0)
- ->get()
- .getDefiningOp<tensor::EmptyOp>()) {
+ auto initTensor =
+ genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+ if (initTensor || isGenericOutsNotUsed(genericOp)) {
dest = packOpDest;
}
// pack(unpack) isn't naively foldable because the unpack op can be from
@@ -1101,12 +1117,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
genericOp, genericOp.getDpsInitOperand(0));
auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
- // If the dps init operand of the generic is a tensor.empty, do not pack it
- // and forward the new tensor.empty as a destination.
+ // Forward the new tensor.empty as a destination if one of the following
+ // holds:
+ // 1) The dps init operand is a tensor.empty.
+ // 2) The dps init is a write-only operand, i.e., it is not used in the
+ // genericOp body.
Value dest = packedOutOperand;
- if (auto initTensor = genericOp.getDpsInitOperand(0)
- ->get()
- .getDefiningOp<tensor::EmptyOp>()) {
+ auto initTensor =
+ genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+ if (initTensor || isGenericOutsNotUsed(genericOp)) {
if (destPack)
dest = destPack.getDest();
}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 31c9e9ed3c501..6fc8d9f152f4e 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -435,6 +435,40 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
// -----
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims_unused_init(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
+ %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<128x256xi32>)
+ outs(%init : tensor<128x256xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32):
+ %4 = arith.addi %arg3, %arg3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<128x256xi32>
+ %empty = tensor.empty() : tensor<16x4x32x16xi32>
+ %pack = linalg.pack %elem
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [0, 1]
+ inner_tiles = [32, 16]
+ into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+ return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims_unused_init
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[ARG1_EMPTY]]
+
+// -----
+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
@@ -497,7 +531,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
+func.func @unpack_element_type_change_no_use(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
%0 = tensor.empty() : tensor<12x56x56x64xf32>
%1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
@@ -509,17 +543,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
}
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @unpack_element_type_change
+// CHECK-LABEL: func.func @unpack_element_type_change_no_use
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -1402,13 +1433,10 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
-// CHECK: %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
-// CHECK-SAME: into %[[ARG2_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME: outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
// CHECK-SAME: into %[[ARG2]]
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
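To make the new condition concrete, here is a minimal sketch of input IR that
the rewrite now handles, distilled from the new test above (the function and
value names, and the omission of outer_dims_perm, are illustrative only):

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @write_only_init(%in: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<4x16x32x16xi32> {
  // %out is never read in the body, so the generic is write-only on its
  // init: the rewrite may forward the pack destination instead of packing
  // %init.
  %elem = linalg.generic {indexing_maps = [#map, #map],
                          iterator_types = ["parallel", "parallel"]}
      ins(%in : tensor<128x256xi32>) outs(%init : tensor<128x256xi32>) {
  ^bb0(%x: i32, %out: i32):
    %sum = arith.addi %x, %x : i32
    linalg.yield %sum : i32
  } -> tensor<128x256xi32>
  %empty = tensor.empty() : tensor<4x16x32x16xi32>
  %pack = linalg.pack %elem inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %empty : tensor<128x256xi32> -> tensor<4x16x32x16xi32>
  return %pack : tensor<4x16x32x16xi32>
}

After bubbling, the pack applies to %in and the generic writes into a fresh
packed tensor.empty destination, as the CHECK lines above verify for the
transposed variant.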
From d543525dcaf8e9856cca494c8bf40de8fcb825ac Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 30 Jun 2025 18:05:25 +0000
Subject: [PATCH 2/2] Simplify implementation according to review feedback
---
.../Linalg/Transforms/DataLayoutPropagation.cpp | 13 +++----------
1 file changed, 3 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 3b8ed6bfb6e6f..31ac87bacf267 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -359,16 +359,9 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
}
static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
- Block *block = genericOp.getBody();
- int numBlockArgs = block->getNumArguments();
- int numDpsOuts = genericOp.getNumDpsInits();
- int initArgStartIndex = numBlockArgs - numDpsOuts;
- for (int i = 0; i < numDpsOuts; ++i) {
- int matchingInitArgIndex = initArgStartIndex + i;
- if (!block->getArgument(matchingInitArgIndex).use_empty())
- return false;
- }
- return true;
+ return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
+ return genericOp.getMatchingBlockArgument(&operand).use_empty();
+ });
}
/// Bubbles up linalg.pack op through a producer generic op. This
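Why the one-liner is equivalent to the removed loop: for linalg.generic, the
operand list (inputs first, then inits) lines up one-to-one with the body's
block arguments, so the interface helper resolves an init operand to exactly
the argument that the manual index arithmetic located. A sketch of that
equivalence (illustration only, not part of the patch; assumes a genericOp in
scope):

  for (OpOperand &initOperand : genericOp.getDpsInitsMutable()) {
    // The interface lookup...
    BlockArgument viaInterface =
        genericOp.getMatchingBlockArgument(&initOperand);
    // ...matches the manual computation from the first patch: the block
    // argument index equals the operand number, because generic's operands
    // and body arguments appear in the same order.
    BlockArgument viaIndex =
        genericOp.getBody()->getArgument(initOperand.getOperandNumber());
    assert(viaInterface == viaIndex);
  }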