[Mlir-commits] [mlir] [mlir][linalg] Allow pack consumer fusion if the tile size is greater than dimension size. (PR #149438)

Han-Chung Wang llvmlistbot at llvm.org
Fri Jul 18 09:44:29 PDT 2025


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/149438

>From 27ca0b5cefcdd9bdaba372c5ba298e6264a0257a Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 17 Jul 2025 19:03:25 -0700
Subject: [PATCH 1/2] [mlir][linalg] Allow pack consumer fusion if the tile
 size is greater than dimension size.

This only happens when the tile size is greater than or equal to the
dimension size. In this case, the slice covers the full dimension, so
the pack op is fusible.

Such IR can be generated during the TileAndFuse process. It is hard to
fix in that driver, so we enable the naive fusion for this case.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  2 +-
 .../tile-and-fuse-consumer.mlir               | 51 +++++++++++++++++++
 2 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 5a10883a6043c..eed431570452a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -911,7 +911,7 @@ struct PackOpTiling
         int64_t destDimSize = packOp.getDestType().getDimSize(dim);
         bool isTiled = failed(cstTileSize) ||
                        ShapedType::isDynamic(srcDimSize) ||
-                       cstTileSize.value() != srcDimSize;
+                       cstTileSize.value() < srcDimSize;
         if (!isTiled) {
           outerDimOffsets.push_back(offsets[dim]);
           if (ShapedType::isStatic(destDimSize)) {
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 7b0a8494a8acb..e9465baac7509 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -451,6 +451,57 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> {
+  %0 = tensor.empty() : tensor<1x4x16x1xf32>
+  %1 = tensor.empty() : tensor<4x4xf32>
+  %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+    %3 = affine.min #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+    }
+  }
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32>
+  return %pack : tensor<1x4x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)>
+//      CHECK: func.func @fuse_pack_consumer_if_single_iteration(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32>
+//  CHECK-DAG:   %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME:      shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+//  CHECK-DAG:      %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+//  CHECK-DAG:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//  CHECK-DAG:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:        outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+
+// -----
+
 // It is valid to fuse the pack op in perfect tiling scenario when the dimension
 // is dynamic and padding is not needed.
 

>From bf1dd3b9a44715f35b1c2f61295530f3521e3b4f Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 18 Jul 2025 09:44:14 -0700
Subject: [PATCH 2/2] add a comment

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index eed431570452a..4aee94ee3e652 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -904,7 +904,9 @@ struct PackOpTiling
 
         // If a dimension is not tiled, it is always valid to fuse the pack op,
         // even if the op has padding semantics. Because it always generates a
-        // full slice along the dimension.
+        // full slice along the dimension. The tile sizes are for unpacked
+        // domain, i.e., `srcDimSize`, so `tileSize >= srcDimSize` means no
+        // tiling at all.
         // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
         // hard check to determine if a dimension is tiled or not.
         int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);



More information about the Mlir-commits mailing list