[Mlir-commits] [mlir] 8cfd9b8 - [MLIR] Make generic skip packing init operand when not used in DataLayoutPropagation (#146139)

llvmlistbot at llvm.org
Tue Jul 1 06:39:34 PDT 2025


Author: Zhuoran Yin
Date: 2025-07-01T09:39:30-04:00
New Revision: 8cfd9b88215acd1bff339c0d4ed60d688dcbcfdd

URL: https://github.com/llvm/llvm-project/commit/8cfd9b88215acd1bff339c0d4ed60d688dcbcfdd
DIFF: https://github.com/llvm/llvm-project/commit/8cfd9b88215acd1bff339c0d4ed60d688dcbcfdd.diff

LOG: [MLIR] Make generic skip packing init operand when not used in DataLayoutPropagation (#146139)

In both `bubbleUpPackOpThroughGenericOp()` and
`pushDownUnPackOpThroughGenericOp()`, we can simplify the lowered IR by
skipping the packing of the init operand when that operand is not used
inside the generic op. Instead of packing an empty tensor, the empty
tensor can be forwarded directly to the generic op's output. This yields
cleaner IR after data layout propagation.
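
For illustration, a minimal sketch of the bubble-up case (value names and
shapes here are hypothetical; the committed tests below exercise the exact
patterns):

    // Before: %init feeds the generic's outs, but its block argument is unused.
    %elem = linalg.generic ... ins(%arg0) outs(%init) -> tensor<128x256xi32>
    %empty = tensor.empty() : tensor<16x4x32x16xi32>
    %pack = linalg.pack %elem ... into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>

    // After: %init is never packed; the packed generic writes directly into
    // the forwarded tensor.empty, and only the real input gets a pack op.
    %packed_in = linalg.pack %arg0 ... into %in_empty
    %res = linalg.generic ... ins(%packed_in) outs(%empty) -> tensor<16x4x32x16xi32>

The same reasoning applies in the unpack push-down direction, where the
destination of the newly created pack op is forwarded instead.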

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
    mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7188987e5e938..31ac87bacf267 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -358,6 +358,12 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   return newGenericOp;
 }
 
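+/// Returns true if all of the dps init (outs) block arguments of `genericOp`
+/// are unused, i.e., every output of the generic op is write-only.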
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
+    return genericOp.getMatchingBlockArgument(&operand).use_empty();
+  });
+}
+
 /// Bubbles up linalg.pack op through a producer generic op. This
 /// swaps pack(generic) to generic(pack). The new generic op works on the
 /// packed domain; pack ops are created for input and output operands. E.g.,
@@ -470,12 +476,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                      genericOp, opOperand);
 
-  // If the dps init operand of the generic is a tensor.empty forward the pack
-  // op destination.
+  // Forward the new tensor.empty as a destination if either of the following
+  // holds:
+  // 1) The dps init operand is a tensor.empty.
+  // 2) The dps init is a write-only operand, i.e., it is not used in the
+  //    genericOp.
   Value dest = packedOutOperand;
-  if (auto initTensor = genericOp.getDpsInitOperand(0)
-                            ->get()
-                            .getDefiningOp<tensor::EmptyOp>()) {
+  auto initTensor =
+      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+  if (initTensor || isGenericOutsNotUsed(genericOp)) {
     dest = packOpDest;
   }
   // pack(unpack) isn't naively foldable because the unpack op can be from
@@ -1101,12 +1110,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                      genericOp, genericOp.getDpsInitOperand(0));
   auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
 
-  // If the dps init operand of the generic is a tensor.empty, do not pack it
-  // and forward the new tensor.empty as a destination.
+  // Forward the new tensor.empty as a destination if either of the following
+  // holds:
+  // 1) The dps init operand is a tensor.empty.
+  // 2) The dps init is a write-only operand, i.e., it is not used in the
+  //    genericOp.
   Value dest = packedOutOperand;
-  if (auto initTensor = genericOp.getDpsInitOperand(0)
-                            ->get()
-                            .getDefiningOp<tensor::EmptyOp>()) {
+  auto initTensor =
+      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+  if (initTensor || isGenericOutsNotUsed(genericOp)) {
     if (destPack)
       dest = destPack.getDest();
   }

diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 31c9e9ed3c501..6fc8d9f152f4e 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -435,6 +435,40 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims_unused_init(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32> {
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %empty = tensor.empty() : tensor<16x4x32x16xi32>
+  %pack = linalg.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [32, 16]
+    into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims_unused_init
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:         %[[PACKED_ARG0:.+]] = linalg.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK-SAME:      outs(%[[ARG1_EMPTY]]
+
+// -----
+
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
 func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
@@ -497,7 +531,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
-func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
+func.func @unpack_element_type_change_no_use(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
   %0 = tensor.empty() : tensor<12x56x56x64xf32>
   %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
   %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
@@ -509,17 +543,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 }
 
 // CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @unpack_element_type_change
+// CHECK-LABEL: func.func @unpack_element_type_change_no_use
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK:         %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[ARG0]]
-// CHECK-SAME:      outs(%[[ARG1_PACK]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1]]
@@ -1402,13 +1433,10 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
 // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK:         %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
-// CHECK:         %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
-// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [4, 8]
-// CHECK-SAME:      into %[[ARG2_PACK_EMPTY]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:    ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME:    outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
+// CHECK-SAME:    outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG2]]
 // CHECK:         return %[[UNPACK]] : tensor<?x64xbf16>

More information about the Mlir-commits mailing list