[Mlir-commits] [mlir] [mlir][linalg] Improve linalg.pack consumer fusion. (PR #148993)

Han-Chung Wang llvmlistbot at llvm.org
Thu Jul 17 12:28:30 PDT 2025


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/148993

>From 3642d259c3ece69cbc41ab74af863d6b4b221839 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 15 Jul 2025 16:50:31 -0700
Subject: [PATCH 1/4] [mlir][linalg] Improve linalg.pack consumer fusion.

If a dimension is not tiled, it is always valid to fuse the pack op,
even if it has padding semantics, because it always generates a full
slice along the dimension.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  41 +--
 .../tile-and-fuse-consumer.mlir               | 278 ++++++++++--------
 2 files changed, 184 insertions(+), 135 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 513cecef29b61..fb9ba4ccb14af 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -887,26 +887,13 @@ struct PackOpTiling
 
     ArrayRef<OpFoldResult> offsets(allOffsets[0]);
     ArrayRef<OpFoldResult> sizes(allSizes[0]);
-
     auto packOp = cast<PackOp>(op);
-    // It is not trivial to infer dest tile from source tile if `packOp` has
-    // padding semantic.
-    if (packOp.getPaddingValue())
-      return failure();
-
     Location loc = packOp.getLoc();
-
     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
     DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
         packOp.getDimAndTileMapping();
     for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
       if (dimAndTileMapping.count(dim)) {
-        FailureOr<int64_t> cstSize =
-            ValueBoundsConstraintSet::computeConstantBound(
-                presburger::BoundType::UB, sizes[dim],
-                /*stopCondition=*/nullptr, /*closedUB=*/true);
-        std::optional<int64_t> cstInnerSize =
-            getConstantIntValue(dimAndTileMapping[dim]);
         // Currently fusing `packOp` as consumer only expects perfect tiling
         // scenario because even if without padding semantic, the `packOp` may
         // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
@@ -916,12 +903,25 @@ struct PackOpTiling
         // (0,0)~(0,4) at first row.
         // 2. the second slice is extracted from (5) to (9) and SHOULD BE
         // respectively inserted into two rows with different length, including
-        // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
-        // them, thus adding below constraint to bypass them temporarily. In
-        // another word, we can only support tiling with consumer if the tile
-        // size for the producer is a multiple of the inner tile size for the
-        // packed dimensions at this moment.
-        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
+        // first row: (0,5) and second row (1,0)~(1,3).
+        // It is hard to coordinate them, thus adding below constraint to bypass
+        // them temporarily. In another word, we can only support tiling with
+        // consumer if the tile size for the producer is either a multiple of
+        // the inner tile size for the packed dimensions or the dimension is not
+        // tiled at this moment.
+        FailureOr<int64_t> cstTileSize =
+            ValueBoundsConstraintSet::computeConstantBound(
+                presburger::BoundType::UB, sizes[dim],
+                /*stopCondition=*/nullptr, /*closedUB=*/true);
+        std::optional<int64_t> cstInnerSize =
+            getConstantIntValue(dimAndTileMapping[dim]);
+        int64_t dimSize = packOp.getSourceType().getDimSize(dim);
+        // TODO: It could be untiled if the `dimSize` is dynamic. It is a hard
+        // check to determine if a dimension is tiled or not.
+        bool isTiled = failed(cstTileSize) || ShapedType::isDynamic(dimSize) ||
+                       cstTileSize.value() != dimSize;
+        if (isTiled && (failed(cstTileSize) || !cstInnerSize ||
+                        *cstTileSize % *cstInnerSize != 0)) {
           return failure();
         }
 
@@ -988,7 +988,8 @@ struct PackOpTiling
         loc, packOp.getDest(), outputOffsets, outputSizes, strides);
     tiledOperands.push_back(outSlice);
 
-    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
+    if (auto val = packOp.getPaddingValue())
+      tiledOperands.push_back(val);
     for (auto tile : packOp.getInnerTiles())
       tiledOperands.push_back(tile);
 
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index d09373bdb3f14..da3592547e125 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -193,33 +193,33 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-    func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
-      %c4 = arith.constant 4 : index
-      %c64 = arith.constant 64 : index
-      %c0 = arith.constant 0 : index
-      %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
-        %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
-        %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
-        %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
-        scf.forall.in_parallel {
-          tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
-          tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
-        }
+  func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+        tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
       }
-      %1 = tensor.empty() : tensor<64x64xf32>
-      %2 = tensor.empty() : tensor<64x64xf32>
-      %3 = tensor.empty() : tensor<64x64xf32>
-      %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
-      ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
-        %6 = arith.mulf %in, %in_0 : f32
-        %7 = arith.subf %out, %6 : f32
-        %8 = arith.addf %out_1, %in : f32
-        linalg.yield %7, %8 : f32, f32
-      } -> (tensor<64x64xf32>, tensor<64x64xf32>)
-      %5 = tensor.empty() : tensor<2048xf32>
-      %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
-      return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
     }
+    %1 = tensor.empty() : tensor<64x64xf32>
+    %2 = tensor.empty() : tensor<64x64xf32>
+    %3 = tensor.empty() : tensor<64x64xf32>
+    %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
+      %6 = arith.mulf %in, %in_0 : f32
+      %7 = arith.subf %out, %6 : f32
+      %8 = arith.addf %out_1, %in : f32
+      linalg.yield %7, %8 : f32, f32
+    } -> (tensor<64x64xf32>, tensor<64x64xf32>)
+    %5 = tensor.empty() : tensor<2048xf32>
+    %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
+    return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
+  }
 }
 
 module attributes {transform.with_named_sequence} {
@@ -269,38 +269,38 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-    func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
-        %c4 = arith.constant 4 : index
-        %c64 = arith.constant 64 : index
-        %c0 = arith.constant 0 : index
-        %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
-            %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
-            %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
-                ^bb0(%in: f32, %in_16: f32, %out: f32):
-                %13 = arith.mulf %in, %in_16 : f32
-                %14 = arith.addf %out, %13 : f32
-                linalg.yield %14 : f32
-            } -> tensor<32x32xf32>
-            scf.forall.in_parallel {
-                tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
-            }
-        }
-        %output = tensor.empty() : tensor<2048xf32>
-        %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
-        return %unpack : tensor<2048xf32>
+  func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
     }
+    %output = tensor.empty() : tensor<2048xf32>
+    %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
+    return %unpack : tensor<2048xf32>
+  }
 }
-  
+
 module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-        %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %loop = transform.structured.match ops{["scf.forall"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
-        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-        transform.yield
-    }
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
 }
 //  CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
 //  CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
@@ -332,38 +332,38 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-    func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
-        %c4 = arith.constant 4 : index
-        %c64 = arith.constant 64 : index
-        %c0 = arith.constant 0 : index
-        %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
-            %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
-            %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
-                ^bb0(%in: f32, %in_16: f32, %out: f32):
-                %13 = arith.mulf %in, %in_16 : f32
-                %14 = arith.addf %out, %13 : f32
-                linalg.yield %14 : f32
-            } -> tensor<32x32xf32>
-            scf.forall.in_parallel {
-                tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
-            }
-        }
-        %output = tensor.empty() : tensor<2047xf32>
-        %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
-        return %unpack : tensor<2047xf32>
+  func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
     }
+    %output = tensor.empty() : tensor<2047xf32>
+    %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
+    return %unpack : tensor<2047xf32>
+  }
 }
-  
+
 module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-        %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %loop = transform.structured.match ops{["scf.forall"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
-        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-        transform.yield
-    }
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
 }
 //  CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
 //  CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
@@ -395,38 +395,38 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-    func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
-        %c4 = arith.constant 4 : index
-        %c64 = arith.constant 64 : index
-        %c0 = arith.constant 0 : index
-        %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
-            %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
-            %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
-                ^bb0(%in: f32, %in_16: f32, %out: f32):
-                %13 = arith.mulf %in, %in_16 : f32
-                %14 = arith.addf %out, %13 : f32
-                linalg.yield %14 : f32
-            } -> tensor<32x32xf32>
-            scf.forall.in_parallel {
-                tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
-            }
-        }
-        %output = tensor.empty() : tensor<4x32x16xf32>
-        %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
-        return %pack : tensor<4x32x16xf32>
+  func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
     }
+    %output = tensor.empty() : tensor<4x32x16xf32>
+    %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
+    return %pack : tensor<4x32x16xf32>
+  }
 }
-  
+
 module attributes {transform.with_named_sequence} {
-    transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
-        %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %loop = transform.structured.match ops{["scf.forall"]} in %arg1
-        : (!transform.any_op) -> !transform.any_op
-        %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
-        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-        transform.yield
-    }
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %loop = transform.structured.match ops{["scf.forall"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op in (%loop)
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
 }
 //      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
 //      CHECK: func.func @fuse_pack_consumer_into_scf_forall(
@@ -451,6 +451,54 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+  %0 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %extracted_slice = tensor.extract_slice %arg0[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
+    %2 = linalg.exp ins(%extracted_slice : tensor<64x32xf32>) outs(%extracted_slice_0 : tensor<64x32xf32>) -> tensor<64x32xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> into tensor<64x32xf32>
+    }
+  }
+  %1 = tensor.empty() : tensor<23x32x3x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
+  return %pack : tensor<23x32x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+//      CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x32x3x16xf32>
+//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) in (2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 32] [1, 1]
+//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1]
+//      CHECK:      %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [3, 16]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1]
+
+// -----
+
 module {
   func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
     %c0 = arith.constant 0 : index
@@ -489,7 +537,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
 //      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
 //      CHECK:   %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
-// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) 
+// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
 // CHECK-SAME:   {
 //      CHECK:          %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
 //      CHECK:          %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
@@ -645,7 +693,7 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
       tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
-    }	
+    }
   }
   %empty = tensor.empty(%dim0) : tensor<?xf32>
   %result = linalg.generic {
@@ -719,7 +767,7 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %
     scf.forall.in_parallel {
       tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
       tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32>
-    }	
+    }
   }
   %empty = tensor.empty(%dim0) : tensor<?xf32>
   %result = linalg.generic {

>From 061d4a2336d958d3bb83fd0ea64d7ce20b9cbc61 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 17 Jul 2025 11:15:31 -0700
Subject: [PATCH 2/4] Restrict the fusion condition.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 27 ++++++---
 .../tile-and-fuse-consumer.mlir               | 59 ++++++++++++++-----
 2 files changed, 63 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index fb9ba4ccb14af..f609c65818e43 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "llvm/Support/Debug.h"
@@ -915,13 +916,25 @@ struct PackOpTiling
                 /*stopCondition=*/nullptr, /*closedUB=*/true);
         std::optional<int64_t> cstInnerSize =
             getConstantIntValue(dimAndTileMapping[dim]);
-        int64_t dimSize = packOp.getSourceType().getDimSize(dim);
-        // TODO: It could be untiled if the `dimSize` is dynamic. It is a hard
-        // check to determine if a dimension is tiled or not.
-        bool isTiled = failed(cstTileSize) || ShapedType::isDynamic(dimSize) ||
-                       cstTileSize.value() != dimSize;
-        if (isTiled && (failed(cstTileSize) || !cstInnerSize ||
-                        *cstTileSize % *cstInnerSize != 0)) {
+        // If a dimension is not tiled, it is always valid to fuse the pack op,
+        // even if the op has padding semantics. Because it always generates a
+        // full slice along the dimension.
+        // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
+        // hard check to determine if a dimension is tiled or not.
+        int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
+        bool isTiled = failed(cstTileSize) ||
+                       ShapedType::isDynamic(srcDimSize) ||
+                       cstTileSize.value() != srcDimSize;
+        int64_t destDimSize = packOp.getDestType().getDimSize(dim);
+        bool needPadding = ShapedType::isDynamic(destDimSize) ||
+                           !cstInnerSize ||
+                           destDimSize * cstInnerSize.value() != srcDimSize;
+        // Prioritize the case that the op already says that it does not need
+        // padding.
+        if (!packOp.getPaddingValue()) {
+          needPadding = false;
+        }
+        if (isTiled && needPadding) {
           return failure();
         }
 
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index da3592547e125..3d32ddd9bed84 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -451,19 +451,19 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
-  %0 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
-    %extracted_slice = tensor.extract_slice %arg0[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
-    %extracted_slice_0 = tensor.extract_slice %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
-    %2 = linalg.exp ins(%extracted_slice : tensor<64x32xf32>) outs(%extracted_slice_0 : tensor<64x32xf32>) -> tensor<64x32xf32>
+func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> {
+  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
     scf.forall.in_parallel {
-      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> into tensor<64x32xf32>
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
     }
   }
-  %1 = tensor.empty() : tensor<23x32x3x16xf32>
+  %1 = tensor.empty() : tensor<23x2x3x16xf32>
   %cst = arith.constant 0.000000e+00 : f32
-  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
-  return %pack : tensor<23x32x3x16xf32>
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x2x3x16xf32>
+  return %pack : tensor<23x2x3x16xf32>
 }
 
 module attributes {transform.with_named_sequence} {
@@ -478,24 +478,51 @@ module attributes {transform.with_named_sequence} {
 //      CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-//  CHECK-DAG:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x32x3x16xf32>
+//  CHECK-DAG:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x2x3x16xf32>
 //  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
-//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) in (2)
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
 // CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
-//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 32] [1, 1]
-//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1]
+//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
 //      CHECK:      %[[ELEM:.*]] = linalg.exp
 // CHECK-SAME:        ins(%[[ELEM_SRC]]
 // CHECK-SAME:        outs(%[[ELEM_DEST]]
 //  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1]
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
 //      CHECK:      %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
 // CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
 // CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [3, 16]
 // CHECK-SAME:        into %[[TILED_PACK_DEST]]
 //      CHECK:      scf.forall.in_parallel {
-//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 32] [1, 1]
-//      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 2, 3, 16] [1, 1, 1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
+
+// -----
+
+func.func @nofuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %1 = tensor.empty() : tensor<23x32x3x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
+  return %pack : tensor<23x32x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
 
 // -----
 

>From 864a9a55c6fafea9ff41f10e9ed757bfce0409cc Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 17 Jul 2025 11:16:35 -0700
Subject: [PATCH 3/4] Fix IR bug

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Interfaces/TilingInterface/tile-and-fuse-consumer.mlir    | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 3d32ddd9bed84..fc64733d7a887 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -399,7 +399,7 @@ module {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
-    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+    %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
       %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
       %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
         ^bb0(%in: f32, %in_16: f32, %out: f32):
@@ -434,7 +434,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
 //      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
-//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 1)
 // CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
 // CHECK-SAME:   {
 //      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]

>From 583639db7bacdc197623df0e20185b865ffa095d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 17 Jul 2025 12:24:57 -0700
Subject: [PATCH 4/4] Fix the fusion logic and add more lit tests

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  59 ++++++----
 .../tile-and-fuse-consumer.mlir               | 107 ++++++++++++++++--
 2 files changed, 137 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f609c65818e43..0513fbfe28148 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -895,48 +895,63 @@ struct PackOpTiling
         packOp.getDimAndTileMapping();
     for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
       if (dimAndTileMapping.count(dim)) {
-        // Currently fusing `packOp` as consumer only expects perfect tiling
-        // scenario because even if without padding semantic, the `packOp` may
-        // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
-        // where the `tileSize` from operand of `packOp` is 5, which is not
-        // exactly divided by `innerTile`(=6) of `packOp`. As the result:
-        // 1. the first slice is extracted from (0) to (4) and inserted into
-        // (0,0)~(0,4) at first row.
-        // 2. the second slice is extracted from (5) to (9) and SHOULD BE
-        // respectively inserted into two rows with different length, including
-        // first row: (0,5) and second row (1,0)~(1,3).
-        // It is hard to coordinate them, thus adding below constraint to bypass
-        // them temporarily. In another word, we can only support tiling with
-        // consumer if the tile size for the producer is either a multiple of
-        // the inner tile size for the packed dimensions or the dimension is not
-        // tiled at this moment.
         FailureOr<int64_t> cstTileSize =
             ValueBoundsConstraintSet::computeConstantBound(
                 presburger::BoundType::UB, sizes[dim],
                 /*stopCondition=*/nullptr, /*closedUB=*/true);
         std::optional<int64_t> cstInnerSize =
             getConstantIntValue(dimAndTileMapping[dim]);
+
         // If a dimension is not tiled, it is always valid to fuse the pack op,
         // even if the op has padding semantics. Because it always generates a
         // full slice along the dimension.
         // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
         // hard check to determine if a dimension is tiled or not.
         int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
+        int64_t destDimSize = packOp.getDestType().getDimSize(dim);
         bool isTiled = failed(cstTileSize) ||
                        ShapedType::isDynamic(srcDimSize) ||
                        cstTileSize.value() != srcDimSize;
-        int64_t destDimSize = packOp.getDestType().getDimSize(dim);
-        bool needPadding = ShapedType::isDynamic(destDimSize) ||
+        if (!isTiled) {
+          outerDimOffsets.push_back(offsets[dim]);
+          if (ShapedType::isStatic(destDimSize)) {
+            outerDimSizes.push_back(b.getIndexAttr(destDimSize));
+          } else {
+            outerDimSizes.push_back(
+                b.createOrFold<tensor::DimOp>(loc, packOp.getDest(), dim));
+          }
+          continue;
+        }
+
+        // If the dimension needs padding, it is not supported because there are
+        // iterations that only write padding values to the whole tile. The
+        // consumer fusion is driven by the source, so it is not possible to map
+        // an empty slice to the tile.
+        bool needExtraPadding = ShapedType::isDynamic(destDimSize) ||
                            !cstInnerSize ||
                            destDimSize * cstInnerSize.value() != srcDimSize;
         // Prioritize the case that the op already says that it does not need
         // padding.
-        if (!packOp.getPaddingValue()) {
-          needPadding = false;
-        }
-        if (isTiled && needPadding) {
+        if (!packOp.getPaddingValue())
+          needExtraPadding = false;
+        if (needExtraPadding)
+          return failure();
+
+        // Currently fusing `packOp` as a consumer only supports perfect tiling,
+        // because even without padding semantics the `packOp` may still yield
+        // incomplete tiles. E.g., take tensor<30xf32> -> tensor<5x6xf32>, where
+        // the `tileSize` of the `packOp` operand is 5, which is not a multiple
+        // of the `innerTile` (=6) of the `packOp`. As a result:
+        // 1. The first slice is extracted from (0) to (4) and inserted into
+        // (0,0)~(0,4) in the first row.
+        // 2. The second slice is extracted from (5) to (9) and SHOULD BE split
+        // and inserted into two rows of different lengths: the first row at
+        // (0,5) and the second row at (1,0)~(1,3).
+        // It is hard to coordinate them, so the constraint below rejects such
+        // cases for now.
+        if ((failed(cstTileSize) || !cstInnerSize ||
+             *cstTileSize % *cstInnerSize != 0))
           return failure();
-        }
 
         using AV = affine::AffineValueExpr;
         affine::AffineBuilder ab(b, loc);
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index fc64733d7a887..daa8341ca5a28 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -395,7 +395,7 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 module {
-  func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+  func.func @fuse_perfect_tiling_pack_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
@@ -429,7 +429,7 @@ module attributes {transform.with_named_sequence} {
   }
 }
 //      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
-//      CHECK: func.func @fuse_pack_consumer_into_scf_forall(
+//      CHECK: func.func @fuse_perfect_tiling_pack_consumer(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
@@ -451,7 +451,10 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> {
+// It is valid to fuse the pack op with padding semantics if the tiled
+// dimension does not need padding.
+
+func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x2x3x16xf32> {
   %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
     %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
     %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -475,7 +478,7 @@ module attributes {transform.with_named_sequence} {
   }
 }
 //      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
-//      CHECK: func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(
+//      CHECK: func.func @fuse_pack_consumer_with_padding_semantics(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
 //  CHECK-DAG:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<23x2x3x16xf32>
@@ -488,18 +491,72 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:        ins(%[[ELEM_SRC]]
 // CHECK-SAME:        outs(%[[ELEM_DEST]]
 //  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1]
 //      CHECK:      %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
 // CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
 // CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [3, 16]
 // CHECK-SAME:        into %[[TILED_PACK_DEST]]
 //      CHECK:      scf.forall.in_parallel {
 //      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-//      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [23, 1, 3, 16] [1, 1, 1, 1]
 
 // -----
 
-func.func @nofuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+// It is valid to fuse the pack if the dimension is not tiled even when it needs
+// extra padding.
+
+func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
+  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %1 = tensor.empty() : tensor<33x2x3x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
+  return %pack : tensor<33x2x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+//      CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+//  CHECK-DAG:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
+//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+//      CHECK:      %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [3, 16]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+
+// -----
+
+// If the dimension is tiled and it needs extra padding, do not fuse the pack
+// op.
+
+func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
   %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
     %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
     %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -526,6 +583,42 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Imperfect tiling is not supported in pack op consumer fusion.
+
+#map = affine_map<(d0) -> (d0 * 5)>
+#map1 = affine_map<(d0) -> (d0)>
+func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x6xf32> {
+  %0 = tensor.empty() : tensor<30xf32>
+  %1 = scf.forall (%arg1) in (6) shared_outs(%arg2 = %0) -> (tensor<30xf32>) {
+    %3 = affine.apply #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%3] [5] [1] : tensor<30xf32> to tensor<5xf32>
+    %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<5xf32>) outs(%extracted_slice_0 : tensor<5xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %5 = arith.addf %in, %in : f32
+      linalg.yield %5 : f32
+    } -> tensor<5xf32>
+    scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
+      tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32>
+    }
+  }
+  %2 = tensor.empty() : tensor<5x6xf32>
+  %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32>
+  return %pack : tensor<5x6xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 module {
   func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
     %c0 = arith.constant 0 : index



More information about the Mlir-commits mailing list