[Mlir-commits] [mlir] [mlir][linalg][transform][python] Clean up _ext.py test. (PR #66469)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 15 00:54:36 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
This PR cleans up the test of the mix-ins of this dialect. Most of the character diff is due to factoring out the creation of the top-level sequence into a decorator. This decorator significantly shortens the definition of the individual tests and can be used in all but one test, where the top-level op is a PDL op. The only functional diff is due to the fact that the decorator uses `transform.any_op` instead of `pdl.operation` for the type of the root handle. The only remaining use of the PDL dialect is now in the test of a PDL-related op.
--
Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66469.diff
1 Files Affected:
- (modified) mlir/test/python/dialects/transform_structured_ext.py (+202-355)
<pre>
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 465e01d8b658f58..19afbb895eb5b48 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -4,6 +4,7 @@
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured
+from typing import Callable
from mlir.dialects.transform import pdl as transform_pdl
@@ -18,33 +19,40 @@ def run(f):
return f
+def create_sequence(func: Callable) -> Callable:
+ def decorated() -> None:
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ )
+ with InsertionPoint(sequence.body):
+ func(sequence.bodyTarget)
+ transform.YieldOp()
+
+ decorated.__name__ = func.__name__
+ return decorated
+
+
@run
-def testBufferizeToAllocationOpCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.BufferizeToAllocationOp(sequence.bodyTarget)
- transform.YieldOp()
+ at create_sequence
+def testBufferizeToAllocationOpCompact(target):
+ structured.BufferizeToAllocationOp(target)
# CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact
# CHECK: transform.sequence
# CHECK: transform.structured.bufferize_to_allocation
@run
-def testBufferizeToAllocationOpArgs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ at create_sequence
+def testBufferizeToAllocationOpArgs(target):
+ structured.BufferizeToAllocationOp(
+ target,
+ memory_space=3,
+ memcpy_op="memref.copy",
+ alloc_op="memref.alloca",
+ bufferize_destination_only=True,
)
- with InsertionPoint(sequence.body):
- structured.BufferizeToAllocationOp(
- sequence.bodyTarget,
- memory_space=3,
- memcpy_op="memref.copy",
- alloc_op="memref.alloca",
- bufferize_destination_only=True,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs
# CHECK: transform.sequence
# CHECK: transform.structured.bufferize_to_allocation
@@ -55,78 +63,54 @@ def testBufferizeToAllocationOpArgs():
@run
-def testDecompose():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.DecomposeOp(sequence.bodyTarget)
- transform.YieldOp()
+ at create_sequence
+def testDecompose(target):
+ structured.DecomposeOp(target)
# CHECK-LABEL: TEST: testDecompose
# CHECK: transform.sequence
# CHECK: transform.structured.decompose
@run
-def testFuseIntoContainingOpTypes():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ at create_sequence
+def testFuseIntoContainingOpTypes(target):
+ fused = structured.MatchOp.match_op_names(target, ["test.dummy"])
+ containing = structured.MatchOp.match_op_names(target, ["test.dummy"])
+ structured.FuseIntoContainingOp(
+ transform.OperationType.get("test.dummy"),
+ transform.OperationType.get("test.dummy"),
+ fused,
+ containing,
)
- with InsertionPoint(sequence.body):
- fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
- containing = structured.MatchOp.match_op_names(
- sequence.bodyTarget, ["test.dummy"]
- )
- structured.FuseIntoContainingOp(
- transform.OperationType.get("test.dummy"),
- transform.OperationType.get("test.dummy"),
- fused,
- containing,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
# CHECK: = transform.structured.fuse_into_containing_op
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">)
@run
-def testFuseIntoContainingOpCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
- containing = structured.MatchOp.match_op_names(
- sequence.bodyTarget, ["test.dummy"]
- )
- structured.FuseIntoContainingOp(fused, containing)
- transform.YieldOp()
+ at create_sequence
+def testFuseIntoContainingOpCompact(target):
+ fused = structured.MatchOp.match_op_names(target, ["test.dummy"])
+ containing = structured.MatchOp.match_op_names(target, ["test.dummy"])
+ structured.FuseIntoContainingOp(fused, containing)
# CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
# CHECK: = transform.structured.fuse_into_containing_op
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
@run
-def testGeneralize():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.GeneralizeOp(sequence.bodyTarget)
- transform.YieldOp()
+ at create_sequence
+def testGeneralize(target):
+ structured.GeneralizeOp(target)
# CHECK-LABEL: TEST: testGeneralize
# CHECK: transform.sequence
# CHECK: transform.structured.generalize
@run
-def testInterchange():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0])
- transform.YieldOp()
+ at create_sequence
+def testInterchange(target):
+ structured.InterchangeOp(target, iterator_interchange=[1, 0])
# CHECK-LABEL: TEST: testInterchange
# CHECK: transform.sequence
# CHECK: transform.structured.interchange
@@ -134,15 +118,11 @@ def testInterchange():
@run
-def testMapCopyToThreadsOpCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ at create_sequence
+def testMapCopyToThreadsOpCompact(target):
+ structured.MapCopyToThreadsOp(
+ target, total_num_threads=32, desired_bit_alignment=128
)
- with InsertionPoint(sequence.body):
- structured.MapCopyToThreadsOp(
- sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
# CHECK: = transform.structured.gpu.map_copy_to_threads
# CHECK-SAME: total_num_threads = 32
@@ -151,19 +131,15 @@ def testMapCopyToThreadsOpCompact():
@run
-def testMapCopyToThreadsOpTypes():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ at create_sequence
+def testMapCopyToThreadsOpTypes(target):
+ structured.MapCopyToThreadsOp(
+ transform.OperationType.get("test.opA"),
+ transform.OperationType.get("test.opB"),
+ target,
+ total_num_threads=32,
+ desired_bit_alignment=128,
)
- with InsertionPoint(sequence.body):
- structured.MapCopyToThreadsOp(
- transform.OperationType.get("test.opA"),
- transform.OperationType.get("test.opB"),
- sequence.bodyTarget,
- total_num_threads=32,
- desired_bit_alignment=128,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
# CHECK: = transform.structured.gpu.map_copy_to_threads
# CHECK-SAME: total_num_threads = 32
@@ -172,13 +148,9 @@ def testMapCopyToThreadsOpTypes():
@run
-def testMatchOpNamesString():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy")
- transform.YieldOp()
+ at create_sequence
+def testMatchOpNamesString(target):
+ structured.MatchOp.match_op_names(target, "test.dummy")
# CHECK-LABEL: TEST: testMatchOpNamesString
# CHECK: transform.structured.match ops
# CHECK-SAME: ["test.dummy"]
@@ -186,13 +158,9 @@ def testMatchOpNamesString():
@run
-def testMatchOpNamesList():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
- transform.YieldOp()
+ at create_sequence
+def testMatchOpNamesList(target):
+ structured.MatchOp.match_op_names(target, ["test.dummy"])
# CHECK-LABEL: TEST: testMatchOpNamesList
# CHECK: transform.structured.match ops
# CHECK-SAME: ["test.dummy"]
@@ -200,13 +168,9 @@ def testMatchOpNamesList():
@run
-def testMaskedVectorizeStatic():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.MaskedVectorizeOp(sequence.bodyTarget, [16, 4])
- transform.YieldOp()
+ at create_sequence
+def testMaskedVectorizeStatic(target):
+ structured.MaskedVectorizeOp(target, [16, 4])
# CHECK-LABEL: TEST: testMaskedVectorizeStatic
# CHECK: transform.sequence
# CHECK: transform.structured.masked_vectorize
@@ -214,14 +178,10 @@ def testMaskedVectorizeStatic():
@run
-def testMaskedVectorizeArray():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- sizes = Attribute.parse("[16, 4]")
- structured.MaskedVectorizeOp(sequence.bodyTarget, sizes)
- transform.YieldOp()
+ at create_sequence
+def testMaskedVectorizeArray(target):
+ sizes = Attribute.parse("[16, 4]")
+ structured.MaskedVectorizeOp(target, sizes)
# CHECK-LABEL: TEST: testMaskedVectorizeArray
# CHECK: transform.sequence
# CHECK: transform.structured.masked_vectorize
@@ -229,15 +189,11 @@ def testMaskedVectorizeArray():
@run
-def testMaskedVectorizeMixed():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
- sz2 = Attribute.parse("4")
- structured.MaskedVectorizeOp(sequence.bodyTarget, [sz1, sz2])
- transform.YieldOp()
+ at create_sequence
+def testMaskedVectorizeMixed(target):
+ sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+ sz2 = Attribute.parse("4")
+ structured.MaskedVectorizeOp(target, [sz1, sz2])
# CHECK-LABEL: TEST: testMaskedVectorizeMixed
# CHECK: transform.sequence
# CHECK: %[[V0:.*]] = transform.structured.match
@@ -246,15 +202,11 @@ def testMaskedVectorizeMixed():
@run
-def testMaskedVectorizeScalable():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
- sz2 = Attribute.parse("4")
- structured.MaskedVectorizeOp(sequence.bodyTarget, [16, [sz1], [sz2], [8]])
- transform.YieldOp()
+ at create_sequence
+def testMaskedVectorizeScalable(target):
+ sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+ sz2 = Attribute.parse("4")
+ structured.MaskedVectorizeOp(target, [16, [sz1], [sz2], [8]])
# CHECK-LABEL: TEST: testMaskedVectorizeScalable
# CHECK: transform.sequence
# CHECK-DAG: %[[V0:.*]] = transform.structured.match
@@ -263,15 +215,9 @@ def testMaskedVectorizeScalable():
@run
-def testMaskedVectorizeArgs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.MaskedVectorizeOp(
- sequence.bodyTarget, [16, 4], vectorize_nd_extract=True
- )
- transform.YieldOp()
+ at create_sequence
+def testMaskedVectorizeArgs(target):
+ structured.MaskedVectorizeOp(target, [16, 4], vectorize_nd_extract=True)
# CHECK-LABEL: TEST: testMaskedVectorizeArgs
# CHECK: transform.sequence
# CHECK: transform.structured.masked_vectorize
@@ -279,17 +225,13 @@ def testMaskedVectorizeArgs():
@run
-def testMatchOpNamesTyped():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ at create_sequence
+def testMatchOpNamesTyped(target):
+ structured.MatchOp.match_op_names(
+ transform.OperationType.get("test.dummy"),
+ target,
+ ["test.dummy"],
)
- with InsertionPoint(sequence.body):
- structured.MatchOp.match_op_names(
- transform.OperationType.get("test.dummy"),
- sequence.bodyTarget,
- ["test.dummy"],
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMatchOpNamesTyped
# CHECK: transform.structured.match ops
# CHECK-SAME: ["test.dummy"]
@@ -297,15 +239,11 @@ def testMatchOpNamesTyped():
@run
-def testMultitileSizesCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ at create_sequence
+def testMultitileSizesCompact(target):
+ structured.MultiTileSizesOp(
+ transform.AnyOpType.get(), target, dimension=1, target_size=42
)
- with InsertionPoint(sequence.body):
- structured.MultiTileSizesOp(
- pdl.OperationType.get(), sequence.bodyTarget, dimension=1, target_size=42
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMultitileSizes
# CHECK: transform.sequence
# CHECK-NOT: divisor
@@ -318,19 +256,15 @@ def testMultitileSizesCompact():
@run
-def testMultitileSizesAllArgs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ at create_sequence
+def testMultitileSizesAllArgs(target):
+ structured.MultiTileSizesOp(
+ transform.AnyOpType.get(),
+ target,
+ dimension=1,
+ target_size=42,
+ divisor=2,
)
- with InsertionPoint(sequence.body):
- structured.MultiTileSizesOp(
- pdl.OperationType.get(),
- sequence.bodyTarget,
- dimension=1,
- target_size=42,
- divisor=2,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMultitileSizes
# CHECK: transform.sequence
# CHECK: transform.structured.multitile_sizes
@@ -340,13 +274,9 @@ def testMultitileSizesAllArgs():
@run
-def testPadOpNoArgs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.PadOp(sequence.bodyTarget)
- transform.YieldOp()
+ at create_sequence
+def testPadOpNoArgs(target):
+ structured.PadOp(target)
# CHECK-LABEL: TEST: testPadOpNoArgs
# CHECK: transform.sequence
# CHECK: transform.structured.pad
@@ -359,21 +289,17 @@ def testPadOpNoArgs():
@run
-def testPadOpArgs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ at create_sequence
+def testPadOpArgs(target):
+ structured.PadOp(
+ target,
+ padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
+ padding_dimensions=Attribute.parse("[1]"),
+ pad_to_multiple_of=[128],
+ pack_paddings=[0],
+ transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
+ copy_back_op="linalg.copy",
)
- with InsertionPoint(sequence.body):
- structured.PadOp(
- sequence.bodyTarget,
- padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
- padding_dimensions=Attribute.parse("[1]"),
- pad_to_multiple_of=[128],
- pack_paddings=[0],
- transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
- copy_back_op="linalg.copy",
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testPadOpArgs
# CHECK: transform.sequence
# CHECK: transform.structured.pad
@@ -386,39 +312,27 @@ def testPadOpArgs():
@run
-def testScalarize():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.ScalarizeOp(sequence.bodyTarget)
- transform.YieldOp()
+ at create_sequence
+def testScalarize(target):
+ structured.ScalarizeOp(target)
# CHECK-LABEL: TEST: testScalarize
# CHECK: transform.structured.scalarize
@run
-def testSplit():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
- structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
- transform.YieldOp()
+ at create_sequence
+def testSplit(target):
+ split = structured.SplitOp(target, dimension=1, split_point=42)
+ structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
# CHECK-LABEL: TEST: testSplit
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
@run
-def testTileCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
- transform.YieldOp()
+ at create_sequence
+def testTileCompact(target):
+ structured.TileOp(target, sizes=[4, 8], interchange=[0, 1])
# CHECK-LABEL: TEST: testTileCompact
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
@@ -426,15 +340,11 @@ def testTileCompact():
@run
-def testTileAttributes():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
+ at create_sequence
+def testTileAttributes(target):
attr = DenseI64ArrayAttr.get([4, 8])
ichange = DenseI64ArrayAttr.get([0, 1])
- with InsertionPoint(sequence.body):
- structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
...
<truncated>
</pre>
</details>
https://github.com/llvm/llvm-project/pull/66469
More information about the Mlir-commits
mailing list