[Mlir-commits] [mlir] 618f231 - [MLIR][Transform] Consolidate result of structured.split into one list (#111171)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 14 18:53:37 PST 2024


Author: Jinyun (Joey) Ye
Date: 2024-11-15T10:53:34+08:00
New Revision: 618f231a6d3ef41d231e2a4d1e2eca4c0d709802

URL: https://github.com/llvm/llvm-project/commit/618f231a6d3ef41d231e2a4d1e2eca4c0d709802
DIFF: https://github.com/llvm/llvm-project/commit/618f231a6d3ef41d231e2a4d1e2eca4c0d709802.diff

LOG: [MLIR][Transform] Consolidate result of structured.split into one list (#111171)

Follow up on a review comment from
https://github.com/llvm/llvm-project/pull/82792#discussion_r1604925239
as a separate PR:

	E.g.:
	```
	%0:2 = transform.structured.split
	```
	is changed to
	```
	%t = transform.structured.split
	%0:2 = transform.split_handle %t
	```

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/python/mlir/dialects/transform/structured.py
    mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
    mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
    mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
    mlir/test/Dialect/Linalg/transform-op-split.mlir
    mlir/test/Dialect/Linalg/transform-ops.mlir
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 25a98a16960f37..f256af2f6b12b8 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -703,8 +703,10 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
                          { target_size = 10, dimension = 1 }
                        : !transform.any_op, !transform.param<i64>,
                          !transform.param<i64>, !transform.param<i64>
-    %low, %high = structured.split %target after %split { dimension = 1 }
+    %handles = structured.split %target after %split { dimension = 1 }
                 : !transform.any_op, !transform.param<i64>
+    %low, %high = transform.split_handle %handles : (!transform.any_op)
+                      -> (!transform.any_op, !transform.any_op)
     %tiled_low, %loop1 = structured.tile_using_for %low [0, %sz1]
                        : (!transform.any_op, !transform.param<i64>)
                       -> (!transform.any_op, !transform.any_op)
@@ -1452,21 +1454,24 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
     operations pointed to by the target handle.
 
     The operation consumes the target handle, but preserves the chunk size
-    handle if provided. Without the `multiway` attribute, it produces two
-    new handles pointing to the two parts of the structured op after splitting,
-    in the same order as the target operand, with the first handle
-    corresponding to the part with lower iteration space indices.
+    handle if provided. Without the `multiway` attribute, it produces a
+    new handle that is a list of the two parts of the structured op after
+    splitting, whose lower index part corresponding to the part with lower
+    iteration space indices.
 
     Multiway split mode is enabled by specifying the `multiway` attribute.
     In this mode a single `target` op is split into multiple parts covering
     the iteration space of the specified dimension. `static_chunk_sizes` and
     `dynamic_chunk_sizes` in this case is a list of chunk sizes that the given
-    dimension should be split into. With `multiway` it produces two handles;
-    the first handle is a list of the multiple parts of the structured op
+    dimension should be split into. With `multiway` it also produces a handle;
+    The result handle is a list of the multiple parts of the structured op
     after splitting, where the target dimensions for each linalg op in the
     list corresponds to the chunk sizes specfied in the input split list.
     If the chunk sizes do not cover the entire iteration space, the leftover
-    chunk is the last payload in the first handle. The second handle is empty.
+    chunk is the last payload in the result handle.
+
+    As the result handle is most of time a list, an `transform.split_handle`
+    is needed to access individual handle.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
@@ -1474,8 +1479,7 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
                        Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
                        I64Attr:$static_chunk_sizes,
                        UnitAttr:$multiway);
-  let results = (outs TransformHandleTypeInterface:$first,
-                      TransformHandleTypeInterface:$second);
+  let results = (outs TransformHandleTypeInterface:$split_list);
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 88e116bce7f595..1956fc634ef395 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2363,10 +2363,10 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     return DiagnosedSilenceableFailure::success();
   };
 
+  SmallVector<Operation *> opList;
   if (isMultiwaySplit) {
 
     // Split a single target operation at multiple points.
-    SmallVector<Operation *> opList;
     TilingInterface head, tail;
     Operation *target = payload.front();
 
@@ -2406,8 +2406,6 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     // Append any leftover parts to the end of the result list.
     if (tail)
       opList.push_back(tail.getOperation());
-    results.set(cast<OpResult>(getFirst()), opList);
-    results.set(cast<OpResult>(getSecond()), {});
 
   } else {
     // Split each target operation.
@@ -2453,9 +2451,11 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
       return diag;
     }
 
-    results.set(cast<OpResult>(getFirst()), first);
-    results.set(cast<OpResult>(getSecond()), second);
+    opList.append(first);
+    if (second.size())
+      opList.append(second);
   }
+  results.set(cast<OpResult>(getSplitList()), opList);
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -2507,7 +2507,7 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(
       SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
       staticChunkSizes);
-  result.addTypes({targetType, targetType});
+  result.addTypes(targetType);
   return success();
 }
 

diff  --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index f6111f516f8c3d..9121aa8e40237b 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -445,7 +445,6 @@ def __init__(
             dynamic_chunk_sizes = chunk_sizes
 
         super().__init__(
-            target.type,
             target.type,
             target,
             dimension=dimension,

diff  --git a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
index 7410ff593d01a2..e02aa0c4db44af 100644
--- a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
@@ -4,7 +4,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
-    %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
+    %linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
     transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.any_op {
     ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
       %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
@@ -65,7 +65,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.param<i64>
-    %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
+    %linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
     transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.param<i64> {
     ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.param<i64>):
       %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
@@ -126,7 +126,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
-    %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
+    %linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
     transform.foreach %linalg_splits, %tile_sizes with_zip_shortest : !transform.any_op, !transform.any_op {
     ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
       %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
@@ -177,4 +177,4 @@ func.func @continuous_tile_dynamic_linalg_matmul(
 // CHECK:     %[[AM16:.*]] = affine.min #[[$MAP12]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]], %[[AM12]]]
 // CHECK:     %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM16]] step %[[C1]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
 // CHECK:       %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<1x?xf32>) -> tensor<1x?xf32>
-// CHECK:       %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>
\ No newline at end of file
+// CHECK:       %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>

diff  --git a/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
index 609766fbdc91f2..12fe8a2a2b6b5c 100644
--- a/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
@@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op
-    %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
+    %splits2 = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
     transform.yield
   }
 }
@@ -58,7 +58,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param<i64>
-    %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
+    %splits2 = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
     transform.yield
   }
 }

diff  --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
index 51332ffce03d1d..af041db9eeffbf 100644
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -6,14 +6,16 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
-    %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
+    %split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
+    %2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
     transform.foreach %5 : !transform.any_op {
     ^bb0(%inner_linalg: !transform.any_op):
       %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
-      %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
+      %split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
+      %inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     }
@@ -111,14 +113,16 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.param<i64>
     %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.param<i64>
-    %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
+    %split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
+    %2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
     %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
     %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
     %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
     transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
     ^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
-      %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
+      %split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
+      %inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
       transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
     }

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index e072fff4c5d771..68c849385ba6b5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -3,7 +3,7 @@
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
+    %1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
     transform.yield
   }
 }
@@ -53,7 +53,7 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
+    %1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
     transform.yield
   }
 }
@@ -138,8 +138,9 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1:2 = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
-    %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
+    %t = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
+    %1:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
     transform.yield
   }
 }
@@ -197,7 +198,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
     // expected-error @below {{expects either a dynamic or a static split point to be provided}}
-    %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %0 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -303,7 +304,7 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error @below {{splitting does not produce the second part for a subset of targets}}
     // expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
-    %1:2 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
+    %1 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
     transform.yield
   }
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
index c152fc887a3a39..06a89fccd5c383 100644
--- a/mlir/test/Dialect/Linalg/transform-ops.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -18,7 +18,8 @@ transform.sequence failures(propagate) {
 
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !transform.any_op):
-  %0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
+  %t = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
+  %0:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
   transform.structured.split %0#0 after %0#1 { dimension = 1 } : !transform.any_op, !transform.any_op
 }
 

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index d029b3bfa6b118..fb4c75b5337928 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -361,11 +361,15 @@ def testScalarize(target):
 @run
 @create_sequence
 def testSplit(target):
-    split = structured.SplitOp(target, dimension=1, chunk_sizes=42)
+    handle = structured.SplitOp(target, dimension=1, chunk_sizes=42)
+    split = transform.SplitHandleOp(
+        [transform.AnyOpType.get(), transform.AnyOpType.get()], handle
+    )
     structured.SplitOp(split.results[0], dimension=3, chunk_sizes=split.results[1])
     # CHECK-LABEL: TEST: testSplit
-    # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
-    # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
+    # CHECK: %[[G:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
+    # CHECK: %[[F:.+]]:2 = split_handle %[[G]]
+    # CHECK: transform.structured.split %[[F]]#0 after %[[F]]#1 {dimension = 3
 
 
 @run


        


More information about the Mlir-commits mailing list