[Mlir-commits] [mlir] 7d040d4 - [mlir][linalg] Handle outer_dims_perm in linalg.pack consumer fusion. (#149426)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 18 09:42:43 PDT 2025


Author: Han-Chung Wang
Date: 2025-07-18T09:42:40-07:00
New Revision: 7d040d4675baf6881cf50c4dba78cc18af85f9ef

URL: https://github.com/llvm/llvm-project/commit/7d040d4675baf6881cf50c4dba78cc18af85f9ef
DIFF: https://github.com/llvm/llvm-project/commit/7d040d4675baf6881cf50c4dba78cc18af85f9ef.diff

LOG: [mlir][linalg] Handle outer_dims_perm in linalg.pack consumer fusion. (#149426)

Signed-off-by: hanhanW <hanhan0912 at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 5a10883a6043c..b059bcc025315 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -893,6 +893,13 @@ struct PackOpTiling
     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
     DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
         packOp.getDimAndTileMapping();
+    SmallVector<int64_t> outerShapeWithoutTranspose(
+        packOp.getDestType().getShape().take_front(packOp.getSourceRank()));
+    if (!packOp.getOuterDimsPerm().empty()) {
+      applyPermutationToVector(
+          outerShapeWithoutTranspose,
+          invertPermutationVector(packOp.getOuterDimsPerm()));
+    }
     for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
       if (dimAndTileMapping.count(dim)) {
         FailureOr<int64_t> cstTileSize =
@@ -908,7 +915,7 @@ struct PackOpTiling
         // TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
         // hard check to determine if a dimension is tiled or not.
         int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
-        int64_t destDimSize = packOp.getDestType().getDimSize(dim);
+        int64_t destDimSize = outerShapeWithoutTranspose[dim];
         bool isTiled = failed(cstTileSize) ||
                        ShapedType::isDynamic(srcDimSize) ||
                        cstTileSize.value() != srcDimSize;

diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 7b0a8494a8acb..20164d5dfd91a 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -451,6 +451,51 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+
+func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
+  %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %1 into %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %pack = linalg.pack %0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] into %arg2 : tensor<64x32xf32> -> tensor<2x64x16x1xf32>
+  return %pack : tensor<2x64x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+//      CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:      %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:        ins(%[[ELEM_SRC]]
+// CHECK-SAME:        outs(%[[ELEM_DEST]]
+//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
+//      CHECK:      %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:        outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
+// CHECK-SAME:        into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
+
+// -----
+
 // It is valid to fuse the pack op in perfect tiling scenario when the dimension
 // is dynamic and padding is not needed.
 


        


More information about the Mlir-commits mailing list