[Mlir-commits] [mlir] [mlir][linalg] Allow pack consumer fusion if the tile size is greater than dimension size. (PR #149438)

Han-Chung Wang llvmlistbot at llvm.org
Fri Jul 18 10:00:39 PDT 2025


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/149438

>From a4ed7d89ea1bfcfdfcf08e43304bd8ae8662424f Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 17 Jul 2025 19:03:25 -0700
Subject: [PATCH 1/3] [mlir][linalg] Allow pack consumer fusion if the tile
 size is greater than dimension size.

This only happens when the tile size is greater than or equal to the
dimension size. In that case, the tile is a full slice, so the pack op
is fusible.

Such IR can be generated during the TileAndFuse process. It is hard to
fix in the driver itself, so we enable the naive fusion for this case.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  2 +-
 .../tile-and-fuse-consumer.mlir               | 50 +++++++++++++++++++
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index b059bcc025315..7581ccd22d8ec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -918,7 +918,7 @@ struct PackOpTiling
         int64_t destDimSize = outerShapeWithoutTranspose[dim];
         bool isTiled = failed(cstTileSize) ||
                        ShapedType::isDynamic(srcDimSize) ||
-                       cstTileSize.value() != srcDimSize;
+                       cstTileSize.value() < srcDimSize;
         if (!isTiled) {
           outerDimOffsets.push_back(offsets[dim]);
           if (ShapedType::isStatic(destDimSize)) {
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 20164d5dfd91a..cdbca7228ded3 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -451,6 +451,56 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> {
+  %0 = tensor.empty() : tensor<1x4x16x1xf32>
+  %1 = tensor.empty() : tensor<4x4xf32>
+  %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+    %3 = affine.min #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+    }
+  }
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32>
+  return %pack : tensor<1x4x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)>
+//      CHECK: func.func @fuse_pack_consumer_if_single_iteration(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32>
+//  CHECK-DAG:   %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME:      shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+//  CHECK-DAG:      %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+//  CHECK-DAG:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//  CHECK-DAG:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:        outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+
+// -----
 
 func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
   %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {

>From 4c992de40674753f64c6a6909b6616a3bb587e49 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 18 Jul 2025 09:44:14 -0700
Subject: [PATCH 2/3] add a comment

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7581ccd22d8ec..0816bcd331c9c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -911,7 +911,9 @@ struct PackOpTiling
 
         // If a dimension is not tiled, it is always valid to fuse the pack op,
         // even if the op has padding semantics. Because it always generates a
-        // full slice along the dimension.
+        // full slice along the dimension. The tile sizes are for unpacked
+        // domain, i.e., `srcDimSize`, so `tileSize < srcDimSize` means no
+        // tiling at all.
         // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
         // hard check to determine if a dimension is tiled or not.
         int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);

>From 6c51a22d24339e289891299664ff31065d48e65f Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 18 Jul 2025 09:57:41 -0700
Subject: [PATCH 3/3] correct the comment

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 0816bcd331c9c..28d99b130963a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -912,8 +912,8 @@ struct PackOpTiling
         // If a dimension is not tiled, it is always valid to fuse the pack op,
         // even if the op has padding semantics. Because it always generates a
         // full slice along the dimension. The tile sizes are for unpacked
-        // domain, i.e., `srcDimSize`, so `tileSize < srcDimSize` means no
-        // tiling at all.
+        // domain, i.e., `srcDimSize`, so `tileSize < srcDimSize` means that the
+        // dimension is tiled.
         // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
         // hard check to determine if a dimension is tiled or not.
         int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);



More information about the Mlir-commits mailing list