[Mlir-commits] [mlir] [mlir][linalg][transform][python] Clean up _ext.py test. (PR #66469)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 15 00:54:36 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir
            
<details>
<summary>Changes</summary>
This PR cleans up the test of the mix-ins of this dialect. Most of the character diff is due to factoring out the creation of the the top-level sequence into a decorator. This decorator siginficantly shortens the definition of the individual tests and can be used in all but one test, where the top-level op is a PDL op. The only functional diff is due to the fact that the decator uses `transform.any_op` instead of `pdl.operation` for the type of the root handle. The only remaining usages of the PDL dialects is now in the test a PDL-related op.
--

Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66469.diff

1 Files Affected:

- (modified) mlir/test/python/dialects/transform_structured_ext.py (+202-355) 


<pre>
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 465e01d8b658f58..19afbb895eb5b48 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -4,6 +4,7 @@
 from mlir.dialects import transform
 from mlir.dialects import pdl
 from mlir.dialects.transform import structured
+from typing import Callable
 from mlir.dialects.transform import pdl as transform_pdl
 
 
@@ -18,33 +19,40 @@ def run(f):
     return f
 
 
+def create_sequence(func: Callable) -&gt; Callable:
+    def decorated() -&gt; None:
+        sequence = transform.SequenceOp(
+            transform.FailurePropagationMode.Propagate,
+            [],
+            transform.AnyOpType.get(),
+        )
+        with InsertionPoint(sequence.body):
+            func(sequence.bodyTarget)
+            transform.YieldOp()
+
+    decorated.__name__ = func.__name__
+    return decorated
+
+
 @run
-def testBufferizeToAllocationOpCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.BufferizeToAllocationOp(sequence.bodyTarget)
-        transform.YieldOp()
+@create_sequence
+def testBufferizeToAllocationOpCompact(target):
+    structured.BufferizeToAllocationOp(target)
     # CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact
     # CHECK: transform.sequence
     # CHECK: transform.structured.bufferize_to_allocation
 
 
 @run
-def testBufferizeToAllocationOpArgs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+@create_sequence
+def testBufferizeToAllocationOpArgs(target):
+    structured.BufferizeToAllocationOp(
+        target,
+        memory_space=3,
+        memcpy_op=&quot;memref.copy&quot;,
+        alloc_op=&quot;memref.alloca&quot;,
+        bufferize_destination_only=True,
     )
-    with InsertionPoint(sequence.body):
-        structured.BufferizeToAllocationOp(
-            sequence.bodyTarget,
-            memory_space=3,
-            memcpy_op=&quot;memref.copy&quot;,
-            alloc_op=&quot;memref.alloca&quot;,
-            bufferize_destination_only=True,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs
     # CHECK: transform.sequence
     # CHECK: transform.structured.bufferize_to_allocation
@@ -55,78 +63,54 @@ def testBufferizeToAllocationOpArgs():
 
 
 @run
-def testDecompose():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.DecomposeOp(sequence.bodyTarget)
-        transform.YieldOp()
+@create_sequence
+def testDecompose(target):
+    structured.DecomposeOp(target)
     # CHECK-LABEL: TEST: testDecompose
     # CHECK: transform.sequence
     # CHECK: transform.structured.decompose
 
 
 @run
-def testFuseIntoContainingOpTypes():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+@create_sequence
+def testFuseIntoContainingOpTypes(target):
+    fused = structured.MatchOp.match_op_names(target, [&quot;test.dummy&quot;])
+    containing = structured.MatchOp.match_op_names(target, [&quot;test.dummy&quot;])
+    structured.FuseIntoContainingOp(
+        transform.OperationType.get(&quot;test.dummy&quot;),
+        transform.OperationType.get(&quot;test.dummy&quot;),
+        fused,
+        containing,
     )
-    with InsertionPoint(sequence.body):
-        fused = structured.MatchOp.match_op_names(sequence.bodyTarget, [&quot;test.dummy&quot;])
-        containing = structured.MatchOp.match_op_names(
-            sequence.bodyTarget, [&quot;test.dummy&quot;]
-        )
-        structured.FuseIntoContainingOp(
-            transform.OperationType.get(&quot;test.dummy&quot;),
-            transform.OperationType.get(&quot;test.dummy&quot;),
-            fused,
-            containing,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
     # CHECK: = transform.structured.fuse_into_containing_op
     # CHECK-SAME: (!transform.any_op, !transform.any_op) -&gt; (!transform.op&lt;&quot;test.dummy&quot;&gt;, !transform.op&lt;&quot;test.dummy&quot;&gt;)
 
 
 @run
-def testFuseIntoContainingOpCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        fused = structured.MatchOp.match_op_names(sequence.bodyTarget, [&quot;test.dummy&quot;])
-        containing = structured.MatchOp.match_op_names(
-            sequence.bodyTarget, [&quot;test.dummy&quot;]
-        )
-        structured.FuseIntoContainingOp(fused, containing)
-        transform.YieldOp()
+@create_sequence
+def testFuseIntoContainingOpCompact(target):
+    fused = structured.MatchOp.match_op_names(target, [&quot;test.dummy&quot;])
+    containing = structured.MatchOp.match_op_names(target, [&quot;test.dummy&quot;])
+    structured.FuseIntoContainingOp(fused, containing)
     # CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
     # CHECK: = transform.structured.fuse_into_containing_op
     # CHECK-SAME: (!transform.any_op, !transform.any_op) -&gt; (!transform.any_op, !transform.any_op)
 
 
 @run
-def testGeneralize():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.GeneralizeOp(sequence.bodyTarget)
-        transform.YieldOp()
+@create_sequence
+def testGeneralize(target):
+    structured.GeneralizeOp(target)
     # CHECK-LABEL: TEST: testGeneralize
     # CHECK: transform.sequence
     # CHECK: transform.structured.generalize
 
 
 @run
-def testInterchange():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0])
-        transform.YieldOp()
+@create_sequence
+def testInterchange(target):
+    structured.InterchangeOp(target, iterator_interchange=[1, 0])
     # CHECK-LABEL: TEST: testInterchange
     # CHECK: transform.sequence
     # CHECK: transform.structured.interchange
@@ -134,15 +118,11 @@ def testInterchange():
 
 
 @run
-def testMapCopyToThreadsOpCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+@create_sequence
+def testMapCopyToThreadsOpCompact(target):
+    structured.MapCopyToThreadsOp(
+        target, total_num_threads=32, desired_bit_alignment=128
     )
-    with InsertionPoint(sequence.body):
-        structured.MapCopyToThreadsOp(
-            sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
     # CHECK: = transform.structured.gpu.map_copy_to_threads
     # CHECK-SAME: total_num_threads = 32
@@ -151,19 +131,15 @@ def testMapCopyToThreadsOpCompact():
 
 
 @run
-def testMapCopyToThreadsOpTypes():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+@create_sequence
+def testMapCopyToThreadsOpTypes(target):
+    structured.MapCopyToThreadsOp(
+        transform.OperationType.get(&quot;test.opA&quot;),
+        transform.OperationType.get(&quot;test.opB&quot;),
+        target,
+        total_num_threads=32,
+        desired_bit_alignment=128,
     )
-    with InsertionPoint(sequence.body):
-        structured.MapCopyToThreadsOp(
-            transform.OperationType.get(&quot;test.opA&quot;),
-            transform.OperationType.get(&quot;test.opB&quot;),
-            sequence.bodyTarget,
-            total_num_threads=32,
-            desired_bit_alignment=128,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
     # CHECK: = transform.structured.gpu.map_copy_to_threads
     # CHECK-SAME: total_num_threads = 32
@@ -172,13 +148,9 @@ def testMapCopyToThreadsOpTypes():
 
 
 @run
-def testMatchOpNamesString():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.MatchOp.match_op_names(sequence.bodyTarget, &quot;test.dummy&quot;)
-        transform.YieldOp()
+@create_sequence
+def testMatchOpNamesString(target):
+    structured.MatchOp.match_op_names(target, &quot;test.dummy&quot;)
     # CHECK-LABEL: TEST: testMatchOpNamesString
     # CHECK: transform.structured.match ops
     # CHECK-SAME: [&quot;test.dummy&quot;]
@@ -186,13 +158,9 @@ def testMatchOpNamesString():
 
 
 @run
-def testMatchOpNamesList():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.MatchOp.match_op_names(sequence.bodyTarget, [&quot;test.dummy&quot;])
-        transform.YieldOp()
+@create_sequence
+def testMatchOpNamesList(target):
+    structured.MatchOp.match_op_names(target, [&quot;test.dummy&quot;])
     # CHECK-LABEL: TEST: testMatchOpNamesList
     # CHECK: transform.structured.match ops
     # CHECK-SAME: [&quot;test.dummy&quot;]
@@ -200,13 +168,9 @@ def testMatchOpNamesList():
 
 
 @run
-def testMaskedVectorizeStatic():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.MaskedVectorizeOp(sequence.bodyTarget, [16, 4])
-        transform.YieldOp()
+@create_sequence
+def testMaskedVectorizeStatic(target):
+    structured.MaskedVectorizeOp(target, [16, 4])
     # CHECK-LABEL: TEST: testMaskedVectorizeStatic
     # CHECK: transform.sequence
     # CHECK: transform.structured.masked_vectorize
@@ -214,14 +178,10 @@ def testMaskedVectorizeStatic():
 
 
 @run
-def testMaskedVectorizeArray():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        sizes = Attribute.parse(&quot;[16, 4]&quot;)
-        structured.MaskedVectorizeOp(sequence.bodyTarget, sizes)
-        transform.YieldOp()
+@create_sequence
+def testMaskedVectorizeArray(target):
+    sizes = Attribute.parse(&quot;[16, 4]&quot;)
+    structured.MaskedVectorizeOp(target, sizes)
     # CHECK-LABEL: TEST: testMaskedVectorizeArray
     # CHECK: transform.sequence
     # CHECK: transform.structured.masked_vectorize
@@ -229,15 +189,11 @@ def testMaskedVectorizeArray():
 
 
 @run
-def testMaskedVectorizeMixed():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, [&quot;arith.constant&quot;])
-        sz2 = Attribute.parse(&quot;4&quot;)
-        structured.MaskedVectorizeOp(sequence.bodyTarget, [sz1, sz2])
-        transform.YieldOp()
+@create_sequence
+def testMaskedVectorizeMixed(target):
+    sz1 = structured.MatchOp.match_op_names(target, [&quot;arith.constant&quot;])
+    sz2 = Attribute.parse(&quot;4&quot;)
+    structured.MaskedVectorizeOp(target, [sz1, sz2])
     # CHECK-LABEL: TEST: testMaskedVectorizeMixed
     # CHECK: transform.sequence
     # CHECK: %[[V0:.*]] = transform.structured.match
@@ -246,15 +202,11 @@ def testMaskedVectorizeMixed():
 
 
 @run
-def testMaskedVectorizeScalable():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, [&quot;arith.constant&quot;])
-        sz2 = Attribute.parse(&quot;4&quot;)
-        structured.MaskedVectorizeOp(sequence.bodyTarget, [16, [sz1], [sz2], [8]])
-        transform.YieldOp()
+@create_sequence
+def testMaskedVectorizeScalable(target):
+    sz1 = structured.MatchOp.match_op_names(target, [&quot;arith.constant&quot;])
+    sz2 = Attribute.parse(&quot;4&quot;)
+    structured.MaskedVectorizeOp(target, [16, [sz1], [sz2], [8]])
     # CHECK-LABEL: TEST: testMaskedVectorizeScalable
     # CHECK: transform.sequence
     # CHECK-DAG: %[[V0:.*]] = transform.structured.match
@@ -263,15 +215,9 @@ def testMaskedVectorizeScalable():
 
 
 @run
-def testMaskedVectorizeArgs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.MaskedVectorizeOp(
-            sequence.bodyTarget, [16, 4], vectorize_nd_extract=True
-        )
-        transform.YieldOp()
+@create_sequence
+def testMaskedVectorizeArgs(target):
+    structured.MaskedVectorizeOp(target, [16, 4], vectorize_nd_extract=True)
     # CHECK-LABEL: TEST: testMaskedVectorizeArgs
     # CHECK: transform.sequence
     # CHECK: transform.structured.masked_vectorize
@@ -279,17 +225,13 @@ def testMaskedVectorizeArgs():
 
 
 @run
-def testMatchOpNamesTyped():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+@create_sequence
+def testMatchOpNamesTyped(target):
+    structured.MatchOp.match_op_names(
+        transform.OperationType.get(&quot;test.dummy&quot;),
+        target,
+        [&quot;test.dummy&quot;],
     )
-    with InsertionPoint(sequence.body):
-        structured.MatchOp.match_op_names(
-            transform.OperationType.get(&quot;test.dummy&quot;),
-            sequence.bodyTarget,
-            [&quot;test.dummy&quot;],
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMatchOpNamesTyped
     # CHECK: transform.structured.match ops
     # CHECK-SAME: [&quot;test.dummy&quot;]
@@ -297,15 +239,11 @@ def testMatchOpNamesTyped():
 
 
 @run
-def testMultitileSizesCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+@create_sequence
+def testMultitileSizesCompact(target):
+    structured.MultiTileSizesOp(
+        transform.AnyOpType.get(), target, dimension=1, target_size=42
     )
-    with InsertionPoint(sequence.body):
-        structured.MultiTileSizesOp(
-            pdl.OperationType.get(), sequence.bodyTarget, dimension=1, target_size=42
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMultitileSizes
     # CHECK: transform.sequence
     # CHECK-NOT: divisor
@@ -318,19 +256,15 @@ def testMultitileSizesCompact():
 
 
 @run
-def testMultitileSizesAllArgs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+@create_sequence
+def testMultitileSizesAllArgs(target):
+    structured.MultiTileSizesOp(
+        transform.AnyOpType.get(),
+        target,
+        dimension=1,
+        target_size=42,
+        divisor=2,
     )
-    with InsertionPoint(sequence.body):
-        structured.MultiTileSizesOp(
-            pdl.OperationType.get(),
-            sequence.bodyTarget,
-            dimension=1,
-            target_size=42,
-            divisor=2,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testMultitileSizes
     # CHECK: transform.sequence
     # CHECK: transform.structured.multitile_sizes
@@ -340,13 +274,9 @@ def testMultitileSizesAllArgs():
 
 
 @run
-def testPadOpNoArgs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.PadOp(sequence.bodyTarget)
-        transform.YieldOp()
+@create_sequence
+def testPadOpNoArgs(target):
+    structured.PadOp(target)
     # CHECK-LABEL: TEST: testPadOpNoArgs
     # CHECK: transform.sequence
     # CHECK: transform.structured.pad
@@ -359,21 +289,17 @@ def testPadOpNoArgs():
 
 
 @run
-def testPadOpArgs():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+@create_sequence
+def testPadOpArgs(target):
+    structured.PadOp(
+        target,
+        padding_values=[FloatAttr.get_f32(42.0), StringAttr.get(&quot;0&quot;)],
+        padding_dimensions=Attribute.parse(&quot;[1]&quot;),
+        pad_to_multiple_of=[128],
+        pack_paddings=[0],
+        transpose_paddings=[[1, Attribute.parse(&quot;0&quot;)], Attribute.parse(&quot;[0, 1]&quot;)],
+        copy_back_op=&quot;linalg.copy&quot;,
     )
-    with InsertionPoint(sequence.body):
-        structured.PadOp(
-            sequence.bodyTarget,
-            padding_values=[FloatAttr.get_f32(42.0), StringAttr.get(&quot;0&quot;)],
-            padding_dimensions=Attribute.parse(&quot;[1]&quot;),
-            pad_to_multiple_of=[128],
-            pack_paddings=[0],
-            transpose_paddings=[[1, Attribute.parse(&quot;0&quot;)], Attribute.parse(&quot;[0, 1]&quot;)],
-            copy_back_op=&quot;linalg.copy&quot;,
-        )
-        transform.YieldOp()
     # CHECK-LABEL: TEST: testPadOpArgs
     # CHECK: transform.sequence
     # CHECK: transform.structured.pad
@@ -386,39 +312,27 @@ def testPadOpArgs():
 
 
 @run
-def testScalarize():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.ScalarizeOp(sequence.bodyTarget)
-        transform.YieldOp()
+@create_sequence
+def testScalarize(target):
+    structured.ScalarizeOp(target)
     # CHECK-LABEL: TEST: testScalarize
     # CHECK: transform.structured.scalarize
 
 
 @run
-def testSplit():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
-        structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
-        transform.YieldOp()
+@create_sequence
+def testSplit(target):
+    split = structured.SplitOp(target, dimension=1, split_point=42)
+    structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
     # CHECK-LABEL: TEST: testSplit
     # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
     # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
 
 
 @run
-def testTileCompact():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
-    with InsertionPoint(sequence.body):
-        structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
-        transform.YieldOp()
+@create_sequence
+def testTileCompact(target):
+    structured.TileOp(target, sizes=[4, 8], interchange=[0, 1])
     # CHECK-LABEL: TEST: testTileCompact
     # CHECK: transform.sequence
     # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
@@ -426,15 +340,11 @@ def testTileCompact():
 
 
 @run
-def testTileAttributes():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
-    )
+@create_sequence
+def testTileAttributes(target):
     attr = DenseI64ArrayAttr.get([4, 8])
     ichange = DenseI64ArrayAttr.get([0, 1])
-    with InsertionPoint(sequence.body):
-        structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
...
<truncated>
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66469


More information about the Mlir-commits mailing list