[Mlir-commits] [mlir] [mlir][linalg][transform][python] Clean up _ext.py test. (PR #66469)

Ingo Müller llvmlistbot at llvm.org
Fri Sep 15 00:53:34 PDT 2023


https://github.com/ingomueller-net created https://github.com/llvm/llvm-project/pull/66469

This PR cleans up the test of the mix-ins of this dialect. Most of the character diff is due to factoring out the creation of the top-level sequence into a decorator. This decorator significantly shortens the definition of the individual tests and can be used in all but one test, where the top-level op is a PDL op. The only functional difference is that the decorator uses `transform.any_op` instead of `pdl.operation` for the type of the root handle (which is why one test now casts the root handle to `linalg.matmul` explicitly). The only remaining usage of the PDL dialect is now in the test of a PDL-related op.

>From f72d9942f6c52f27f9dda20ce175e72b38fef3c9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 15 Sep 2023 07:39:08 +0000
Subject: [PATCH] [mlir][linalg][transform][python] Clean up _ext.py test.

This PR cleans up the test of the mix-ins of this dialect. Most of the
character diff is due to factoring out the creation of the top-level
sequence into a decorator. This decorator significantly shortens the
definition of the individual tests and can be used in all but one test,
where the top-level op is a PDL op. The only functional difference is
that the decorator uses `transform.any_op` instead of `pdl.operation`
for the type of the root handle (which is why one test now casts the
root handle to `linalg.matmul` explicitly). The only remaining usage of
the PDL dialect is now in the test of a PDL-related op.
---
 .../dialects/transform_structured_ext.py      | 557 +++++++-----------
 1 file changed, 202 insertions(+), 355 deletions(-)

diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 465e01d8b658f58..19afbb895eb5b48 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -4,6 +4,7 @@
 from mlir.dialects import transform
 from mlir.dialects import pdl
 from mlir.dialects.transform import structured
+from typing import Callable
 from mlir.dialects.transform import pdl as transform_pdl
 
 
@@ -18,33 +19,40 @@ def run(f):
     return f
 
 
+def create_sequence(func: Callable) -> Callable:
+    def decorated() -> None:
+        sequence = transform.SequenceOp(
+            transform.FailurePropagationMode.Propagate,
+            [],
+            transform.AnyOpType.get(),
+        )
+        with InsertionPoint(sequence.body):
+            func(sequence.bodyTarget)
+            transform.YieldOp()
+
+    decorated.__name__ = func.__name__
+    return decorated
+
+
 @run
-def testBufferizeToAllocationOpCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.BufferizeToAllocationOp(sequence.bodyTarget)
-        transform.YieldOp()
+@create_sequence
+def testBufferizeToAllocationOpCompact(target):
+    structured.BufferizeToAllocationOp(target)
     # CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact
     # CHECK: transform.sequence
     # CHECK: transform.structured.bufferize_to_allocation
 
 
 @run
-def testBufferizeToAllocationOpArgs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+@create_sequence
+def testBufferizeToAllocationOpArgs(target):
+    structured.BufferizeToAllocationOp(
+        target,
+        memory_space=3,
+        memcpy_op="memref.copy",
+        alloc_op="memref.alloca",
+        bufferize_destination_only=True,
     )
-    with InsertionPoint(sequence.body):
-        structured.BufferizeToAllocationOp(
-            sequence.bodyTarget,
-            memory_space=3,
-            memcpy_op="memref.copy",
-            alloc_op="memref.alloca",
-            bufferize_destination_only=True,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs
     # CHECK: transform.sequence
     # CHECK: transform.structured.bufferize_to_allocation
@@ -55,78 +63,54 @@ def testBufferizeToAllocationOpArgs():
 
 
 @run
-def testDecompose():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.DecomposeOp(sequence.bodyTarget)
-        transform.YieldOp()
+@create_sequence
+def testDecompose(target):
+    structured.DecomposeOp(target)
     # CHECK-LABEL: TEST: testDecompose
     # CHECK: transform.sequence
     # CHECK: transform.structured.decompose
 
 
 @run
-def testFuseIntoContainingOpTypes():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+@create_sequence
+def testFuseIntoContainingOpTypes(target):
+    fused = structured.MatchOp.match_op_names(target, ["test.dummy"])
+    containing = structured.MatchOp.match_op_names(target, ["test.dummy"])
+    structured.FuseIntoContainingOp(
+        transform.OperationType.get("test.dummy"),
+        transform.OperationType.get("test.dummy"),
+        fused,
+        containing,
     )
-    with InsertionPoint(sequence.body):
-        fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
-        containing = structured.MatchOp.match_op_names(
-            sequence.bodyTarget, ["test.dummy"]
-        )
-        structured.FuseIntoContainingOp(
-            transform.OperationType.get("test.dummy"),
-            transform.OperationType.get("test.dummy"),
-            fused,
-            containing,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
     # CHECK: = transform.structured.fuse_into_containing_op
     # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">)
 
 
 @run
-def testFuseIntoContainingOpCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
-        containing = structured.MatchOp.match_op_names(
-            sequence.bodyTarget, ["test.dummy"]
-        )
-        structured.FuseIntoContainingOp(fused, containing)
-        transform.YieldOp()
+@create_sequence
+def testFuseIntoContainingOpCompact(target):
+    fused = structured.MatchOp.match_op_names(target, ["test.dummy"])
+    containing = structured.MatchOp.match_op_names(target, ["test.dummy"])
+    structured.FuseIntoContainingOp(fused, containing)
     # CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
     # CHECK: = transform.structured.fuse_into_containing_op
     # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
 
 
 @run
-def testGeneralize():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.GeneralizeOp(sequence.bodyTarget)
-        transform.YieldOp()
+@create_sequence
+def testGeneralize(target):
+    structured.GeneralizeOp(target)
     # CHECK-LABEL: TEST: testGeneralize
     # CHECK: transform.sequence
     # CHECK: transform.structured.generalize
 
 
 @run
-def testInterchange():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0])
-        transform.YieldOp()
+@create_sequence
+def testInterchange(target):
+    structured.InterchangeOp(target, iterator_interchange=[1, 0])
     # CHECK-LABEL: TEST: testInterchange
     # CHECK: transform.sequence
     # CHECK: transform.structured.interchange
@@ -134,15 +118,11 @@ def testInterchange():
 
 
 @run
-def testMapCopyToThreadsOpCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+@create_sequence
+def testMapCopyToThreadsOpCompact(target):
+    structured.MapCopyToThreadsOp(
+        target, total_num_threads=32, desired_bit_alignment=128
     )
-    with InsertionPoint(sequence.body):
-        structured.MapCopyToThreadsOp(
-            sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
     # CHECK: = transform.structured.gpu.map_copy_to_threads
     # CHECK-SAME: total_num_threads = 32
@@ -151,19 +131,15 @@ def testMapCopyToThreadsOpCompact():
 
 
 @run
-def testMapCopyToThreadsOpTypes():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+@create_sequence
+def testMapCopyToThreadsOpTypes(target):
+    structured.MapCopyToThreadsOp(
+        transform.OperationType.get("test.opA"),
+        transform.OperationType.get("test.opB"),
+        target,
+        total_num_threads=32,
+        desired_bit_alignment=128,
     )
-    with InsertionPoint(sequence.body):
-        structured.MapCopyToThreadsOp(
-            transform.OperationType.get("test.opA"),
-            transform.OperationType.get("test.opB"),
-            sequence.bodyTarget,
-            total_num_threads=32,
-            desired_bit_alignment=128,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
     # CHECK: = transform.structured.gpu.map_copy_to_threads
     # CHECK-SAME: total_num_threads = 32
@@ -172,13 +148,9 @@ def testMapCopyToThreadsOpTypes():
 
 
 @run
-def testMatchOpNamesString():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy")
-        transform.YieldOp()
+@create_sequence
+def testMatchOpNamesString(target):
+    structured.MatchOp.match_op_names(target, "test.dummy")
     # CHECK-LABEL: TEST: testMatchOpNamesString
     # CHECK: transform.structured.match ops
     # CHECK-SAME: ["test.dummy"]
@@ -186,13 +158,9 @@ def testMatchOpNamesString():
 
 
 @run
-def testMatchOpNamesList():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
-        transform.YieldOp()
+@create_sequence
+def testMatchOpNamesList(target):
+    structured.MatchOp.match_op_names(target, ["test.dummy"])
     # CHECK-LABEL: TEST: testMatchOpNamesList
     # CHECK: transform.structured.match ops
     # CHECK-SAME: ["test.dummy"]
@@ -200,13 +168,9 @@ def testMatchOpNamesList():
 
 
 @run
-def testMaskedVectorizeStatic():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.MaskedVectorizeOp(sequence.bodyTarget, [16, 4])
-        transform.YieldOp()
+@create_sequence
+def testMaskedVectorizeStatic(target):
+    structured.MaskedVectorizeOp(target, [16, 4])
     # CHECK-LABEL: TEST: testMaskedVectorizeStatic
     # CHECK: transform.sequence
     # CHECK: transform.structured.masked_vectorize
@@ -214,14 +178,10 @@ def testMaskedVectorizeStatic():
 
 
 @run
-def testMaskedVectorizeArray():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        sizes = Attribute.parse("[16, 4]")
-        structured.MaskedVectorizeOp(sequence.bodyTarget, sizes)
-        transform.YieldOp()
+@create_sequence
+def testMaskedVectorizeArray(target):
+    sizes = Attribute.parse("[16, 4]")
+    structured.MaskedVectorizeOp(target, sizes)
     # CHECK-LABEL: TEST: testMaskedVectorizeArray
     # CHECK: transform.sequence
     # CHECK: transform.structured.masked_vectorize
@@ -229,15 +189,11 @@ def testMaskedVectorizeArray():
 
 
 @run
-def testMaskedVectorizeMixed():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
-        sz2 = Attribute.parse("4")
-        structured.MaskedVectorizeOp(sequence.bodyTarget, [sz1, sz2])
-        transform.YieldOp()
+@create_sequence
+def testMaskedVectorizeMixed(target):
+    sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+    sz2 = Attribute.parse("4")
+    structured.MaskedVectorizeOp(target, [sz1, sz2])
     # CHECK-LABEL: TEST: testMaskedVectorizeMixed
     # CHECK: transform.sequence
     # CHECK: %[[V0:.*]] = transform.structured.match
@@ -246,15 +202,11 @@ def testMaskedVectorizeMixed():
 
 
 @run
-def testMaskedVectorizeScalable():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
-        sz2 = Attribute.parse("4")
-        structured.MaskedVectorizeOp(sequence.bodyTarget, [16, [sz1], [sz2], [8]])
-        transform.YieldOp()
+@create_sequence
+def testMaskedVectorizeScalable(target):
+    sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+    sz2 = Attribute.parse("4")
+    structured.MaskedVectorizeOp(target, [16, [sz1], [sz2], [8]])
     # CHECK-LABEL: TEST: testMaskedVectorizeScalable
     # CHECK: transform.sequence
     # CHECK-DAG: %[[V0:.*]] = transform.structured.match
@@ -263,15 +215,9 @@ def testMaskedVectorizeScalable():
 
 
 @run
-def testMaskedVectorizeArgs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.MaskedVectorizeOp(
-            sequence.bodyTarget, [16, 4], vectorize_nd_extract=True
-        )
-        transform.YieldOp()
+@create_sequence
+def testMaskedVectorizeArgs(target):
+    structured.MaskedVectorizeOp(target, [16, 4], vectorize_nd_extract=True)
     # CHECK-LABEL: TEST: testMaskedVectorizeArgs
     # CHECK: transform.sequence
     # CHECK: transform.structured.masked_vectorize
@@ -279,17 +225,13 @@ def testMaskedVectorizeArgs():
 
 
 @run
-def testMatchOpNamesTyped():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+@create_sequence
+def testMatchOpNamesTyped(target):
+    structured.MatchOp.match_op_names(
+        transform.OperationType.get("test.dummy"),
+        target,
+        ["test.dummy"],
     )
-    with InsertionPoint(sequence.body):
-        structured.MatchOp.match_op_names(
-            transform.OperationType.get("test.dummy"),
-            sequence.bodyTarget,
-            ["test.dummy"],
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMatchOpNamesTyped
     # CHECK: transform.structured.match ops
     # CHECK-SAME: ["test.dummy"]
@@ -297,15 +239,11 @@ def testMatchOpNamesTyped():
 
 
 @run
-def testMultitileSizesCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+@create_sequence
+def testMultitileSizesCompact(target):
+    structured.MultiTileSizesOp(
+        transform.AnyOpType.get(), target, dimension=1, target_size=42
     )
-    with InsertionPoint(sequence.body):
-        structured.MultiTileSizesOp(
-            pdl.OperationType.get(), sequence.bodyTarget, dimension=1, target_size=42
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMultitileSizes
     # CHECK: transform.sequence
     # CHECK-NOT: divisor
@@ -318,19 +256,15 @@ def testMultitileSizesCompact():
 
 
 @run
-def testMultitileSizesAllArgs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+@create_sequence
+def testMultitileSizesAllArgs(target):
+    structured.MultiTileSizesOp(
+        transform.AnyOpType.get(),
+        target,
+        dimension=1,
+        target_size=42,
+        divisor=2,
     )
-    with InsertionPoint(sequence.body):
-        structured.MultiTileSizesOp(
-            pdl.OperationType.get(),
-            sequence.bodyTarget,
-            dimension=1,
-            target_size=42,
-            divisor=2,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMultitileSizes
     # CHECK: transform.sequence
     # CHECK: transform.structured.multitile_sizes
@@ -340,13 +274,9 @@ def testMultitileSizesAllArgs():
 
 
 @run
-def testPadOpNoArgs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.PadOp(sequence.bodyTarget)
-        transform.YieldOp()
+@create_sequence
+def testPadOpNoArgs(target):
+    structured.PadOp(target)
     # CHECK-LABEL: TEST: testPadOpNoArgs
     # CHECK: transform.sequence
     # CHECK: transform.structured.pad
@@ -359,21 +289,17 @@ def testPadOpNoArgs():
 
 
 @run
-def testPadOpArgs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+@create_sequence
+def testPadOpArgs(target):
+    structured.PadOp(
+        target,
+        padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
+        padding_dimensions=Attribute.parse("[1]"),
+        pad_to_multiple_of=[128],
+        pack_paddings=[0],
+        transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
+        copy_back_op="linalg.copy",
     )
-    with InsertionPoint(sequence.body):
-        structured.PadOp(
-            sequence.bodyTarget,
-            padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
-            padding_dimensions=Attribute.parse("[1]"),
-            pad_to_multiple_of=[128],
-            pack_paddings=[0],
-            transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
-            copy_back_op="linalg.copy",
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testPadOpArgs
     # CHECK: transform.sequence
     # CHECK: transform.structured.pad
@@ -386,39 +312,27 @@ def testPadOpArgs():
 
 
 @run
-def testScalarize():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.ScalarizeOp(sequence.bodyTarget)
-        transform.YieldOp()
+@create_sequence
+def testScalarize(target):
+    structured.ScalarizeOp(target)
     # CHECK-LABEL: TEST: testScalarize
     # CHECK: transform.structured.scalarize
 
 
 @run
-def testSplit():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
-        structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
-        transform.YieldOp()
+@create_sequence
+def testSplit(target):
+    split = structured.SplitOp(target, dimension=1, split_point=42)
+    structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
     # CHECK-LABEL: TEST: testSplit
     # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
     # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
 
 
 @run
-def testTileCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
-        transform.YieldOp()
+@create_sequence
+def testTileCompact(target):
+    structured.TileOp(target, sizes=[4, 8], interchange=[0, 1])
     # CHECK-LABEL: TEST: testTileCompact
     # CHECK: transform.sequence
     # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
@@ -426,15 +340,11 @@ def testTileCompact():
 
 
 @run
-def testTileAttributes():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
+@create_sequence
+def testTileAttributes(target):
     attr = DenseI64ArrayAttr.get([4, 8])
     ichange = DenseI64ArrayAttr.get([0, 1])
-    with InsertionPoint(sequence.body):
-        structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
-        transform.YieldOp()
+    structured.TileOp(target, sizes=attr, interchange=ichange)
     # CHECK-LABEL: TEST: testTileAttributes
     # CHECK: transform.sequence
     # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
@@ -442,15 +352,9 @@ def testTileAttributes():
 
 
 @run
-def testTileZero():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.TileOp(
-            sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3]
-        )
-        transform.YieldOp()
+@create_sequence
+def testTileZero(target):
+    structured.TileOp(target, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
     # CHECK-LABEL: TEST: testTileZero
     # CHECK: transform.sequence
     # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
@@ -480,32 +384,22 @@ def testTileDynamic():
 
 
 @run
-def testTileExplicitLoopTypeSingle():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.TileOp(
-            transform.OperationType.get("scf.for"), sequence.bodyTarget, sizes=[2, 3, 4]
-        )
-        transform.YieldOp()
+@create_sequence
+def testTileExplicitLoopTypeSingle(target):
+    structured.TileOp(transform.OperationType.get("scf.for"), target, sizes=[2, 3, 4])
     # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
     # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
     # CHECK-COUNT-3: !transform.op<"scf.for">
 
 
 @run
-def testTileExplicitLoopTypeAll():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
+@create_sequence
+def testTileExplicitLoopTypeAll(target):
     types = [
         transform.OperationType.get(x)
         for x in ["scf.for", "scf.parallel", "scf.forall"]
     ]
-    with InsertionPoint(sequence.body):
-        structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
-        transform.YieldOp()
+    structured.TileOp(types, target, sizes=[2, 3, 4])
     # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
     # CHECK: = transform.structured.tile
     # CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
@@ -513,31 +407,22 @@ def testTileExplicitLoopTypeAll():
 
 
 @run
-def testTileScalable():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+@create_sequence
+def testTileScalable(target):
+    structured.TileOp(
+        target,
+        sizes=[4, [2]],
     )
-    with InsertionPoint(sequence.body):
-        structured.TileOp(
-            sequence.bodyTarget,
-            sizes=[4, [2]],
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testTileScalable
     # CHECK: transform.sequence
     # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, [2]]
 
 
 @run
-def testTileToForallCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
-        [],
-        transform.OperationType.get("linalg.matmul"),
-    )
-    with InsertionPoint(sequence.body):
-        structured.TileToForallOp(sequence.bodyTarget, num_threads=[2, 3, 4])
-        transform.YieldOp()
+@create_sequence
+def testTileToForallCompact(target):
+    matmul = transform.CastOp(transform.OperationType.get("linalg.matmul"), target)
+    structured.TileToForallOp(matmul, num_threads=[2, 3, 4])
     # CHECK-LABEL: TEST: testTileToForallCompact
     # CHECK: = transform.structured.tile_to_forall_op
     # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
@@ -545,18 +430,14 @@ def testTileToForallCompact():
 
 
 @run
-def testTileToForallLoopsAndTileOpTypes():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+@create_sequence
+def testTileToForallLoopsAndTileOpTypes(target):
+    structured.TileToForallOp(
+        transform.OperationType.get("scf.forall"),  # loops_type
+        transform.OperationType.get("linalg.matmul"),  # tiled_op_type
+        target,
+        num_threads=[2, 3, 4],
     )
-    with InsertionPoint(sequence.body):
-        structured.TileToForallOp(
-            transform.OperationType.get("scf.forall"),  # loops_type
-            transform.OperationType.get("linalg.matmul"),  # tiled_op_type
-            sequence.bodyTarget,
-            num_threads=[2, 3, 4],
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes
     # CHECK: = transform.structured.tile_to_forall_op
     # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
@@ -564,76 +445,54 @@ def testTileToForallLoopsAndTileOpTypes():
 
 
 @run
-def testTileToForallTileSizes():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4])
-        transform.YieldOp()
+@create_sequence
+def testTileToForallTileSizes(target):
+    structured.TileToForallOp(target, tile_sizes=[2, 3, 4])
     # CHECK-LABEL: TEST: testTileToForallTileSizes
     # CHECK: = transform.structured.tile_to_forall_op
     # CHECK-SAME: num_threads [] tile_sizes [2, 3, 4]
 
 
 @run
-def testTileToForallMixedDynamic():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
-        structured.TileToForallOp(sequence.bodyTarget, num_threads=[n, 3, 4])
-        transform.YieldOp()
+@create_sequence
+def testTileToForallMixedDynamic(target):
+    n = structured.MatchOp.match_op_names(target, ["test.dummy"])
+    structured.TileToForallOp(target, num_threads=[n, 3, 4])
     # CHECK-LABEL: TEST: testTileToForallMixedDynamic
     # CHECK: = transform.structured.tile_to_forall_op
     # CHECK-SAME: num_threads [%{{.*}} : !transform.any_op, 3, 4]
 
 
 @run
-def testTileToForallPackedDynamic():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
-        structured.TileToForallOp(sequence.bodyTarget, num_threads=n)
-        transform.YieldOp()
+@create_sequence
+def testTileToForallPackedDynamic(target):
+    n = structured.MatchOp.match_op_names(target, ["test.dummy"])
+    structured.TileToForallOp(target, num_threads=n)
     # CHECK-LABEL: TEST: testTileToForallPackedDynamic
     # CHECK: = transform.structured.tile_to_forall_op
     # CHECK-SAME: num_threads *(%0 : !transform.any_op)
 
 
 @run
-def testTileToForallMapping():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
-        structured.TileToForallOp(
-            sequence.bodyTarget, num_threads=[2, 3], mapping=mapping
-        )
-        transform.YieldOp()
+@create_sequence
+def testTileToForallMapping(target):
+    mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
+    structured.TileToForallOp(target, num_threads=[2, 3], mapping=mapping)
     # CHECK-LABEL: TEST: testTileToForallMapping
     # CHECK: = transform.structured.tile_to_forall_op
     # CHECK-SAME: mapping = [#gpu.thread<y>, #gpu.thread<x>]
 
 
 @run
-def testVectorizeAllAttrs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+@create_sequence
+def testVectorizeAllAttrs(target):
+    structured.VectorizeOp(
+        target,
+        disable_multi_reduction_to_contract_patterns=True,
+        disable_transfer_permutation_map_lowering_patterns=True,
+        vectorize_nd_extract=True,
+        vectorize_padding=True,
     )
-    with InsertionPoint(sequence.body):
-        structured.VectorizeOp(
-            sequence.bodyTarget,
-            disable_multi_reduction_to_contract_patterns=True,
-            disable_transfer_permutation_map_lowering_patterns=True,
-            vectorize_nd_extract=True,
-            vectorize_padding=True,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testVectorizeAllAttrs
     # CHECK: transform.sequence
     # CHECK: = transform.structured.vectorize
@@ -644,19 +503,15 @@ def testVectorizeAllAttrs():
 
 
 @run
-def testVectorizeNoAttrs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+@create_sequence
+def testVectorizeNoAttrs(target):
+    structured.VectorizeOp(
+        target,
+        disable_multi_reduction_to_contract_patterns=False,
+        disable_transfer_permutation_map_lowering_patterns=False,
+        vectorize_nd_extract=False,
+        vectorize_padding=False,
     )
-    with InsertionPoint(sequence.body):
-        structured.VectorizeOp(
-            sequence.bodyTarget,
-            disable_multi_reduction_to_contract_patterns=False,
-            disable_transfer_permutation_map_lowering_patterns=False,
-            vectorize_nd_extract=False,
-            vectorize_padding=False,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testVectorizeNoAttrs
     # CHECK: transform.sequence
     # CHECK: = transform.structured.vectorize
@@ -667,20 +522,16 @@ def testVectorizeNoAttrs():
 
 
 @run
-def testMatchInterfaceEnum():
+@create_sequence
+def testMatchInterfaceEnum(target):
     names = ArrayAttr.get([StringAttr.get("test.dummy")])
     result_type = transform.AnyOpType.get()
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    fused = structured.MatchOp.__base__(
+        result_type,
+        target,
+        ops=names,
+        interface=structured.MatchInterfaceEnum.LinalgOp,
     )
-    with InsertionPoint(sequence.body):
-        fused = structured.MatchOp.__base__(
-            result_type,
-            sequence.bodyTarget,
-            ops=names,
-            interface=structured.MatchInterfaceEnum.LinalgOp,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMatchInterfaceEnum
     # CHECK: transform.sequence
     # CHECK: = transform.structured.match
@@ -688,7 +539,8 @@ def testMatchInterfaceEnum():
 
 
 @run
-def testMatchInterfaceEnumReplaceAttributeBuilder():
+@create_sequence
+def testMatchInterfaceEnumReplaceAttributeBuilder(target):
     @register_attribute_builder("MatchInterfaceEnum", replace=True)
     def match_interface_enum(x, context):
         if x == "LinalgOp":
@@ -699,17 +551,12 @@ def match_interface_enum(x, context):
 
     names = ArrayAttr.get([StringAttr.get("test.dummy")])
     result_type = transform.AnyOpType.get()
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    fused = structured.MatchOp.__base__(
+        result_type,
+        target,
+        ops=names,
+        interface="TilingInterface",
     )
-    with InsertionPoint(sequence.body):
-        fused = structured.MatchOp.__base__(
-            result_type,
-            sequence.bodyTarget,
-            ops=names,
-            interface="TilingInterface",
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMatchInterfaceEnumReplaceAttributeBuilder
     # CHECK: transform.sequence
     # CHECK: = transform.structured.match



More information about the Mlir-commits mailing list