[Mlir-commits] [mlir] c10f33e - [mlir][linalg] Fuse transform op - variadic tile sizes (#194657)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 30 01:06:14 PDT 2026


Author: Adam Siemieniuk
Date: 2026-04-30T10:06:09+02:00
New Revision: c10f33e8935de001a10f46bc370dc12aa67e5674

URL: https://github.com/llvm/llvm-project/commit/c10f33e8935de001a10f46bc370dc12aa67e5674
DIFF: https://github.com/llvm/llvm-project/commit/c10f33e8935de001a10f46bc370dc12aa67e5674.diff

LOG: [mlir][linalg] Fuse transform op - variadic tile sizes (#194657)

Extends the 'structured.fuse' op to accept a packed handle containing
a variable number of tile sizes.

Using packed handles allows tiling decisions to be made at runtime, improving
transform schedule flexibility and reusability.
The extension's design follows the existing approach of the transform
'structured.tile_using_forall' op so that the two ops are used in a similar way.

When tiling with nested loops, all created loops are packed into
a single result handle. For each target op, its loops are appended to
the result handle, in the same order as the target ops.

Assisted-by: Claude

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/python/mlir/dialects/_ods_common.py
    mlir/python/mlir/dialects/transform/structured.py
    mlir/test/Dialect/Linalg/transform-op-fuse.mlir
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index cb61177bc7533..1a59f4c7d1acb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -423,6 +423,12 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
     and loop interchange permutation can be provided as either static
     attributes or dynamic values (transform parameters or payload handles).
 
+    Alternatively, tile sizes can be provided as a single handle containing a
+    variadic number of values. In that case, the number of loops generated is
+    determined at runtime from the number of values in the packed handle.
+    For each target, created loops are appended to the single return handle in
+    the same order as the target operations.
+
     If `apply_cleanup` is true then slice canonicalization is applied between
     fusion steps. If `use_forall` is true then tiling method generates a
     `scf.forall` loop instead of `scf.for` loops.
@@ -432,6 +438,7 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
     (ins TransformHandleTypeInterface:$target,
         Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_sizes,
         Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_interchange,
+        Optional<TransformAnyParamTypeOrAnyHandle> : $packed_tile_sizes,
         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_interchange,
         UnitAttr:$apply_cleanup,
@@ -465,7 +472,9 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
 
   let assemblyFormat = [{
     $target oilist(
-      `tile_sizes` custom<DynamicIndexList>($tile_sizes, $static_tile_sizes) |
+      `tile_sizes` custom<PackedOrDynamicIndexList>($packed_tile_sizes,
+                                                    $tile_sizes,
+                                                    $static_tile_sizes) |
       `interchange` custom<DynamicIndexList>($tile_interchange, $static_tile_interchange)
     )
     attr-dict `:` functional-type(operands, results)

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index baa57f8920094..f44693096b26b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -654,6 +654,7 @@ void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
         /*target=*/target,
         /*tile_sizes=*/dynamicTileSizes,
         /*tile_interchange=*/dynamicTileInterchange,
+        /*packed_tile_sizes=*/Value(),
         /*static_tile_sizes=*/staticTileSizesAttr,
         /*static_tile_interchange=*/staticTileInterchangeAttr,
         /*apply_cleanup=*/applyCleanup,
@@ -666,10 +667,12 @@ template <typename Range>
 static LogicalResult applyTilingToAll(
     RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
     unsigned numLoops, transform::TransformResults &transformResults,
+    bool packedResults,
     function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
         applyFn) {
   SmallVector<Operation *> tiledLinalgOps;
   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
+  size_t numTargets = llvm::range_size(payloadOps);
 
   for (Operation *target : payloadOps) {
     auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
@@ -704,8 +707,22 @@ static LogicalResult applyTilingToAll(
   }
 
   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
-  for (unsigned int i = 0; i < numLoops; ++i)
-    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
+  if (packedResults) {
+    // In case of packed results, all created loops are assigned to a single
+    // handle. Loops are returned in order of targets such as:
+    //   %loops_handle = {
+    //     target0:loop0, ..., target0:loopN,
+    //     target1:loop0, ..., target1:loopN,
+    //     ... }
+    SmallVector<Operation *> flattenedLoopOps;
+    for (unsigned int idx = 0; idx < numTargets; ++idx)
+      for (unsigned int i = 0; i < numLoops; ++i)
+        flattenedLoopOps.push_back(loopOps[i][idx]);
+    transformResults.set(transformOp->getOpResult(1), flattenedLoopOps);
+  } else {
+    for (unsigned int i = 0; i < numLoops; ++i)
+      transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
+  }
 
   return success();
 }
@@ -716,9 +733,13 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
                          mlir::transform::TransformState &state) {
   auto transformOp = cast<TransformOpInterface>(getOperation());
 
-  SmallVector<int64_t> tileSizes;
-  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
-      state, transformOp, getMixedTileSizes(), tileSizes);
+  SmallVector<OpFoldResult> mixedTileSizes;
+  DiagnosedSilenceableFailure status =
+      getPackedTileSizes()
+          ? unpackSingleIndexResultPayloadOperations(
+                state, transformOp, mixedTileSizes, getPackedTileSizes())
+          : unpackSingleIndexResultPayloadOperations(
+                state, transformOp, mixedTileSizes, getMixedTileSizes());
   if (!status.succeeded())
     return status;
   SmallVector<int64_t> tileInterchange;
@@ -733,9 +754,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
   tilingOptions.setLoopType(useForall
                                 ? scf::SCFTilingOptions::LoopType::ForallOp
                                 : scf::SCFTilingOptions::LoopType::ForOp);
-  SmallVector<OpFoldResult> tileSizesOfr =
-      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
-  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
+  tilingOptions = tilingOptions.setTileSizes(mixedTileSizes);
   scf::SCFTileAndFuseOptions tileAndFuseOptions;
   tileAndFuseOptions.tilingOptions = tilingOptions;
 
@@ -748,11 +767,20 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
     tileAndFuseOptions.cleanupPatterns = std::move(patterns);
   }
 
-  size_t numLoops =
-      useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
+  size_t numLoops;
+  if (useForall) {
+    numLoops = 1;
+  } else {
+    numLoops = llvm::count_if(mixedTileSizes, [](OpFoldResult ofr) {
+      auto attr = dyn_cast<Attribute>(ofr);
+      if (!attr)
+        return true;
+      return cast<IntegerAttr>(attr).getInt() != 0;
+    });
+  }
   LogicalResult result = applyTilingToAll(
       rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
-      transformResults,
+      transformResults, /*packedResults=*/getPackedTileSizes() != nullptr,
       [&](TilingInterface tilingInterfaceOp)
           -> FailureOr<scf::SCFTileAndFuseResult> {
         return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
@@ -763,6 +791,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
 }
 
 LogicalResult transform::FuseOp::verify() {
+  bool hasPackedTiles = getPackedTileSizes() != nullptr;
+  if (!getMixedTileSizes().empty() && hasPackedTiles)
+    return emitOpError(
+        "tile_sizes and packed_tile_sizes are mutually exclusive");
+
   auto iterspace_rank = getStaticTileSizes().size();
   ArrayRef<int64_t> permutation = getStaticTileInterchange();
   if (permutation.size() > iterspace_rank)
@@ -782,8 +815,9 @@ LogicalResult transform::FuseOp::verify() {
   }
 
   ArrayRef<int64_t> sizes = getStaticTileSizes();
-  size_t numExpectedLoops =
-      getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
+  size_t numExpectedLoops = getUseForall() || hasPackedTiles
+                                ? 1
+                                : sizes.size() - llvm::count(sizes, 0);
   if (numExpectedLoops != getNumResults() - 1)
     return emitOpError() << "expects " << numExpectedLoops << " loop results";
 
@@ -803,6 +837,7 @@ void transform::FuseOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTargetMutable(), effects);
   onlyReadsHandle(getTileSizesMutable(), effects);
+  onlyReadsHandle(getPackedTileSizesMutable(), effects);
   onlyReadsHandle(getTileInterchangeMutable(), effects);
   producesHandle(getOperation()->getOpResults(), effects);
   modifiesPayload(effects);

diff  --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 10abd06ff266e..7f1bd2183a0c5 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -240,6 +240,8 @@ def _dispatch_mixed_values(
         for size in values or []:
             if isinstance(size, int):
                 static_values.append(size)
+            elif isinstance(size, IntegerAttr):
+                static_values.append(size.value)
             else:
                 static_values.append(ShapedType.get_dynamic_size())
                 dynamic_values.append(size)

diff  --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index d9ab504f0de54..a3c3057ddb834 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -183,15 +183,19 @@ def __init__(
         tile_interchange = tile_interchange if tile_interchange else []
         (
             dynamic_tile_sizes,
+            packed_tile_sizes,
             static_tile_sizes,
-            _,
-        ) = _dispatch_dynamic_index_list(tile_sizes)
+        ) = _dispatch_mixed_values(tile_sizes)
         (
             dynamic_tile_interchange,
             static_tile_interchange,
             _,
         ) = _dispatch_dynamic_index_list(tile_interchange)
-        num_loops = 1 if use_forall else sum(1 for v in static_tile_sizes if v != 0)
+        num_loops = (
+            1
+            if use_forall or packed_tile_sizes is not None
+            else sum(1 for v in static_tile_sizes if v != 0)
+        )
 
         if isinstance(loop_types_or_target, (Operation, Value, OpView)):
             loop_types = [transform.AnyOpType.get()] * num_loops
@@ -210,6 +214,7 @@ def __init__(
             target,
             tile_sizes=dynamic_tile_sizes,
             tile_interchange=dynamic_tile_interchange,
+            packed_tile_sizes=packed_tile_sizes,
             static_tile_sizes=static_tile_sizes,
             static_tile_interchange=static_tile_interchange,
             apply_cleanup=apply_cleanup,

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index b05dc1f295a49..dab8491708104 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -112,6 +112,131 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes
+func.func @fuse_unary_packed_tile_sizes(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  //     CHECK: %[[RES:.*]] = scf.for
+  //     CHECK:    scf.for
+  //     CHECK:       linalg.exp
+  //     CHECK:       linalg.add
+  //     CHECK: return %[[RES]]
+  %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c32 = transform.param.constant 32 : i64 -> !transform.any_param
+    %c64 = transform.param.constant 64 : i64 -> !transform.any_param
+    %tiles = transform.merge_handles %c32, %c64 : !transform.any_param
+    %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles)
+      : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
+    // Verify that correct number of loops is present in packed result.
+    %loop:2 = transform.split_handle %loops : (!transform.any_op)
+      -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes_forall
+func.func @fuse_unary_packed_tile_sizes_forall(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  //     CHECK: %[[RES:.*]] = scf.forall
+  //     CHECK:       linalg.exp
+  //     CHECK:       linalg.add
+  //     CHECK: return %[[RES]]
+  %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c32 = transform.param.constant 32 : i64 -> !transform.any_param
+    %c64 = transform.param.constant 64 : i64 -> !transform.any_param
+    %tiles = transform.merge_handles %c32, %c64 : !transform.any_param
+    %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles) {use_forall}
+      : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
+    // Verify that correct number of loops is present in packed result.
+    %loop:1 = transform.split_handle %loops : (!transform.any_op)
+      -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_unary_packed_tile_sizes_multiple_targets
+func.func @fuse_unary_packed_tile_sizes_multiple_targets(
+    %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  //     CHECK: scf.for
+  //     CHECK:    scf.for
+  //     CHECK:       linalg.add
+  //     CHECK: %[[RES:.*]] = scf.for
+  //     CHECK:    scf.for
+  //     CHECK:       linalg.add
+  //     CHECK: return %[[RES]]
+  %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c32 = transform.param.constant 32 : i64 -> !transform.any_param
+    %c64 = transform.param.constant 64 : i64 -> !transform.any_param
+    %tiles = transform.merge_handles %c32, %c64 : !transform.any_param
+    %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles)
+      : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
+    // Verify that correct number of loops is present in packed result.
+    %loop:4 = transform.split_handle %loops : (!transform.any_op)
+      -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fuse_no_tiling_packed_tile_sizes
+func.func @fuse_no_tiling_packed_tile_sizes(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  //     CHECK-NOT: scf.for
+  //     CHECK: linalg.exp
+  //     CHECK: %[[RES:.*]] = linalg.add
+  //     CHECK: return %[[RES]]
+  %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c0 = transform.param.constant 0 : i64 -> !transform.any_param
+    %tiles = transform.merge_handles %c0, %c0 : !transform.any_param
+    %1, %loops = transform.structured.fuse %0 tile_sizes *(%tiles)
+      : (!transform.any_op, !transform.any_param) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @interchange_reduction
 //  CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>)
 func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> {

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index e58b7646316fc..fcede61100e00 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -191,6 +191,33 @@ def testFuseOpAttributes(target):
     # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
 
 
+@run
+@create_sequence
+def testFuseOpPackedTileSizes(target):
+    tiles = structured.MatchOp.match_op_names(target, ["arith.constant"])
+    structured.FuseOp(target, tile_sizes=tiles)
+    # CHECK-LABEL: TEST: testFuseOpPackedTileSizes
+    # CHECK: transform.sequence
+    # CHECK: %[[T:.*]] = transform.structured.match
+    # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse
+    # CHECK-SAME: tile_sizes *(%[[T]])
+    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpPackedTileSizesForall(target):
+    tiles = structured.MatchOp.match_op_names(target, ["arith.constant"])
+    structured.FuseOp(target, tile_sizes=tiles, use_forall=True)
+    # CHECK-LABEL: TEST: testFuseOpPackedTileSizesForall
+    # CHECK: transform.sequence
+    # CHECK: %[[T:.*]] = transform.structured.match
+    # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse
+    # CHECK-SAME: tile_sizes *(%[[T]])
+    # CHECK-SAME: {use_forall}
+    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
 @run
 @create_sequence
 def testGeneralize(target):


        


More information about the Mlir-commits mailing list