[Mlir-commits] [mlir] [mlir][linalg][transform][python] Clean up _ext.py test. (PR #66469)
Ingo Müller
llvmlistbot at llvm.org
Fri Sep 15 00:53:34 PDT 2023
https://github.com/ingomueller-net created https://github.com/llvm/llvm-project/pull/66469
This PR cleans up the test of the mix-ins of this dialect. Most of the character diff is due to factoring out the creation of the top-level sequence into a decorator. This decorator significantly shortens the definition of the individual tests and can be used in all but one test, where the top-level op is a PDL op. The only functional diff is due to the fact that the decorator uses `transform.any_op` instead of `pdl.operation` for the type of the root handle. The only remaining usage of the PDL dialect is now in the test of a PDL-related op.
>From f72d9942f6c52f27f9dda20ce175e72b38fef3c9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 15 Sep 2023 07:39:08 +0000
Subject: [PATCH] [mlir][linalg][transform][python] Clean up _ext.py test.
This PR cleans up the test of the mix-ins of this dialect. Most of the
character diff is due to factoring out the creation of the top-level
sequence into a decorator. This decorator significantly shortens the
definition of the individual tests and can be used in all but one test,
where the top-level op is a PDL op. The only functional diff is due to
the fact that the decorator uses `transform.any_op` instead of
`pdl.operation` for the type of the root handle. The only remaining
usage of the PDL dialect is now in the test of a PDL-related op.
---
.../dialects/transform_structured_ext.py | 557 +++++++-----------
1 file changed, 202 insertions(+), 355 deletions(-)
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 465e01d8b658f58..19afbb895eb5b48 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -4,6 +4,7 @@
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured
+from typing import Callable
from mlir.dialects.transform import pdl as transform_pdl
@@ -18,33 +19,40 @@ def run(f):
return f
+def create_sequence(func: Callable) -> Callable:
+ def decorated() -> None:
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ )
+ with InsertionPoint(sequence.body):
+ func(sequence.bodyTarget)
+ transform.YieldOp()
+
+ decorated.__name__ = func.__name__
+ return decorated
+
+
@run
-def testBufferizeToAllocationOpCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.BufferizeToAllocationOp(sequence.bodyTarget)
- transform.YieldOp()
+ at create_sequence
+def testBufferizeToAllocationOpCompact(target):
+ structured.BufferizeToAllocationOp(target)
# CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact
# CHECK: transform.sequence
# CHECK: transform.structured.bufferize_to_allocation
@run
-def testBufferizeToAllocationOpArgs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ at create_sequence
+def testBufferizeToAllocationOpArgs(target):
+ structured.BufferizeToAllocationOp(
+ target,
+ memory_space=3,
+ memcpy_op="memref.copy",
+ alloc_op="memref.alloca",
+ bufferize_destination_only=True,
)
- with InsertionPoint(sequence.body):
- structured.BufferizeToAllocationOp(
- sequence.bodyTarget,
- memory_space=3,
- memcpy_op="memref.copy",
- alloc_op="memref.alloca",
- bufferize_destination_only=True,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs
# CHECK: transform.sequence
# CHECK: transform.structured.bufferize_to_allocation
@@ -55,78 +63,54 @@ def testBufferizeToAllocationOpArgs():
@run
-def testDecompose():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.DecomposeOp(sequence.bodyTarget)
- transform.YieldOp()
+ at create_sequence
+def testDecompose(target):
+ structured.DecomposeOp(target)
# CHECK-LABEL: TEST: testDecompose
# CHECK: transform.sequence
# CHECK: transform.structured.decompose
@run
-def testFuseIntoContainingOpTypes():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ at create_sequence
+def testFuseIntoContainingOpTypes(target):
+ fused = structured.MatchOp.match_op_names(target, ["test.dummy"])
+ containing = structured.MatchOp.match_op_names(target, ["test.dummy"])
+ structured.FuseIntoContainingOp(
+ transform.OperationType.get("test.dummy"),
+ transform.OperationType.get("test.dummy"),
+ fused,
+ containing,
)
- with InsertionPoint(sequence.body):
- fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
- containing = structured.MatchOp.match_op_names(
- sequence.bodyTarget, ["test.dummy"]
- )
- structured.FuseIntoContainingOp(
- transform.OperationType.get("test.dummy"),
- transform.OperationType.get("test.dummy"),
- fused,
- containing,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
# CHECK: = transform.structured.fuse_into_containing_op
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">)
@run
-def testFuseIntoContainingOpCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
- containing = structured.MatchOp.match_op_names(
- sequence.bodyTarget, ["test.dummy"]
- )
- structured.FuseIntoContainingOp(fused, containing)
- transform.YieldOp()
+ at create_sequence
+def testFuseIntoContainingOpCompact(target):
+ fused = structured.MatchOp.match_op_names(target, ["test.dummy"])
+ containing = structured.MatchOp.match_op_names(target, ["test.dummy"])
+ structured.FuseIntoContainingOp(fused, containing)
# CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
# CHECK: = transform.structured.fuse_into_containing_op
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
@run
-def testGeneralize():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.GeneralizeOp(sequence.bodyTarget)
- transform.YieldOp()
+ at create_sequence
+def testGeneralize(target):
+ structured.GeneralizeOp(target)
# CHECK-LABEL: TEST: testGeneralize
# CHECK: transform.sequence
# CHECK: transform.structured.generalize
@run
-def testInterchange():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0])
- transform.YieldOp()
+ at create_sequence
+def testInterchange(target):
+ structured.InterchangeOp(target, iterator_interchange=[1, 0])
# CHECK-LABEL: TEST: testInterchange
# CHECK: transform.sequence
# CHECK: transform.structured.interchange
@@ -134,15 +118,11 @@ def testInterchange():
@run
-def testMapCopyToThreadsOpCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ at create_sequence
+def testMapCopyToThreadsOpCompact(target):
+ structured.MapCopyToThreadsOp(
+ target, total_num_threads=32, desired_bit_alignment=128
)
- with InsertionPoint(sequence.body):
- structured.MapCopyToThreadsOp(
- sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
# CHECK: = transform.structured.gpu.map_copy_to_threads
# CHECK-SAME: total_num_threads = 32
@@ -151,19 +131,15 @@ def testMapCopyToThreadsOpCompact():
@run
-def testMapCopyToThreadsOpTypes():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ at create_sequence
+def testMapCopyToThreadsOpTypes(target):
+ structured.MapCopyToThreadsOp(
+ transform.OperationType.get("test.opA"),
+ transform.OperationType.get("test.opB"),
+ target,
+ total_num_threads=32,
+ desired_bit_alignment=128,
)
- with InsertionPoint(sequence.body):
- structured.MapCopyToThreadsOp(
- transform.OperationType.get("test.opA"),
- transform.OperationType.get("test.opB"),
- sequence.bodyTarget,
- total_num_threads=32,
- desired_bit_alignment=128,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
# CHECK: = transform.structured.gpu.map_copy_to_threads
# CHECK-SAME: total_num_threads = 32
@@ -172,13 +148,9 @@ def testMapCopyToThreadsOpTypes():
@run
-def testMatchOpNamesString():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy")
- transform.YieldOp()
+ at create_sequence
+def testMatchOpNamesString(target):
+ structured.MatchOp.match_op_names(target, "test.dummy")
# CHECK-LABEL: TEST: testMatchOpNamesString
# CHECK: transform.structured.match ops
# CHECK-SAME: ["test.dummy"]
@@ -186,13 +158,9 @@ def testMatchOpNamesString():
@run
-def testMatchOpNamesList():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
- transform.YieldOp()
+ at create_sequence
+def testMatchOpNamesList(target):
+ structured.MatchOp.match_op_names(target, ["test.dummy"])
# CHECK-LABEL: TEST: testMatchOpNamesList
# CHECK: transform.structured.match ops
# CHECK-SAME: ["test.dummy"]
@@ -200,13 +168,9 @@ def testMatchOpNamesList():
@run
-def testMaskedVectorizeStatic():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.MaskedVectorizeOp(sequence.bodyTarget, [16, 4])
- transform.YieldOp()
+ at create_sequence
+def testMaskedVectorizeStatic(target):
+ structured.MaskedVectorizeOp(target, [16, 4])
# CHECK-LABEL: TEST: testMaskedVectorizeStatic
# CHECK: transform.sequence
# CHECK: transform.structured.masked_vectorize
@@ -214,14 +178,10 @@ def testMaskedVectorizeStatic():
@run
-def testMaskedVectorizeArray():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- sizes = Attribute.parse("[16, 4]")
- structured.MaskedVectorizeOp(sequence.bodyTarget, sizes)
- transform.YieldOp()
+ at create_sequence
+def testMaskedVectorizeArray(target):
+ sizes = Attribute.parse("[16, 4]")
+ structured.MaskedVectorizeOp(target, sizes)
# CHECK-LABEL: TEST: testMaskedVectorizeArray
# CHECK: transform.sequence
# CHECK: transform.structured.masked_vectorize
@@ -229,15 +189,11 @@ def testMaskedVectorizeArray():
@run
-def testMaskedVectorizeMixed():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
- sz2 = Attribute.parse("4")
- structured.MaskedVectorizeOp(sequence.bodyTarget, [sz1, sz2])
- transform.YieldOp()
+ at create_sequence
+def testMaskedVectorizeMixed(target):
+ sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+ sz2 = Attribute.parse("4")
+ structured.MaskedVectorizeOp(target, [sz1, sz2])
# CHECK-LABEL: TEST: testMaskedVectorizeMixed
# CHECK: transform.sequence
# CHECK: %[[V0:.*]] = transform.structured.match
@@ -246,15 +202,11 @@ def testMaskedVectorizeMixed():
@run
-def testMaskedVectorizeScalable():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
- sz2 = Attribute.parse("4")
- structured.MaskedVectorizeOp(sequence.bodyTarget, [16, [sz1], [sz2], [8]])
- transform.YieldOp()
+ at create_sequence
+def testMaskedVectorizeScalable(target):
+ sz1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+ sz2 = Attribute.parse("4")
+ structured.MaskedVectorizeOp(target, [16, [sz1], [sz2], [8]])
# CHECK-LABEL: TEST: testMaskedVectorizeScalable
# CHECK: transform.sequence
# CHECK-DAG: %[[V0:.*]] = transform.structured.match
@@ -263,15 +215,9 @@ def testMaskedVectorizeScalable():
@run
-def testMaskedVectorizeArgs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.MaskedVectorizeOp(
- sequence.bodyTarget, [16, 4], vectorize_nd_extract=True
- )
- transform.YieldOp()
+ at create_sequence
+def testMaskedVectorizeArgs(target):
+ structured.MaskedVectorizeOp(target, [16, 4], vectorize_nd_extract=True)
# CHECK-LABEL: TEST: testMaskedVectorizeArgs
# CHECK: transform.sequence
# CHECK: transform.structured.masked_vectorize
@@ -279,17 +225,13 @@ def testMaskedVectorizeArgs():
@run
-def testMatchOpNamesTyped():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ at create_sequence
+def testMatchOpNamesTyped(target):
+ structured.MatchOp.match_op_names(
+ transform.OperationType.get("test.dummy"),
+ target,
+ ["test.dummy"],
)
- with InsertionPoint(sequence.body):
- structured.MatchOp.match_op_names(
- transform.OperationType.get("test.dummy"),
- sequence.bodyTarget,
- ["test.dummy"],
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMatchOpNamesTyped
# CHECK: transform.structured.match ops
# CHECK-SAME: ["test.dummy"]
@@ -297,15 +239,11 @@ def testMatchOpNamesTyped():
@run
-def testMultitileSizesCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ at create_sequence
+def testMultitileSizesCompact(target):
+ structured.MultiTileSizesOp(
+ transform.AnyOpType.get(), target, dimension=1, target_size=42
)
- with InsertionPoint(sequence.body):
- structured.MultiTileSizesOp(
- pdl.OperationType.get(), sequence.bodyTarget, dimension=1, target_size=42
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMultitileSizes
# CHECK: transform.sequence
# CHECK-NOT: divisor
@@ -318,19 +256,15 @@ def testMultitileSizesCompact():
@run
-def testMultitileSizesAllArgs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ at create_sequence
+def testMultitileSizesAllArgs(target):
+ structured.MultiTileSizesOp(
+ transform.AnyOpType.get(),
+ target,
+ dimension=1,
+ target_size=42,
+ divisor=2,
)
- with InsertionPoint(sequence.body):
- structured.MultiTileSizesOp(
- pdl.OperationType.get(),
- sequence.bodyTarget,
- dimension=1,
- target_size=42,
- divisor=2,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMultitileSizes
# CHECK: transform.sequence
# CHECK: transform.structured.multitile_sizes
@@ -340,13 +274,9 @@ def testMultitileSizesAllArgs():
@run
-def testPadOpNoArgs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.PadOp(sequence.bodyTarget)
- transform.YieldOp()
+ at create_sequence
+def testPadOpNoArgs(target):
+ structured.PadOp(target)
# CHECK-LABEL: TEST: testPadOpNoArgs
# CHECK: transform.sequence
# CHECK: transform.structured.pad
@@ -359,21 +289,17 @@ def testPadOpNoArgs():
@run
-def testPadOpArgs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ at create_sequence
+def testPadOpArgs(target):
+ structured.PadOp(
+ target,
+ padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
+ padding_dimensions=Attribute.parse("[1]"),
+ pad_to_multiple_of=[128],
+ pack_paddings=[0],
+ transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
+ copy_back_op="linalg.copy",
)
- with InsertionPoint(sequence.body):
- structured.PadOp(
- sequence.bodyTarget,
- padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
- padding_dimensions=Attribute.parse("[1]"),
- pad_to_multiple_of=[128],
- pack_paddings=[0],
- transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
- copy_back_op="linalg.copy",
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testPadOpArgs
# CHECK: transform.sequence
# CHECK: transform.structured.pad
@@ -386,39 +312,27 @@ def testPadOpArgs():
@run
-def testScalarize():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.ScalarizeOp(sequence.bodyTarget)
- transform.YieldOp()
+ at create_sequence
+def testScalarize(target):
+ structured.ScalarizeOp(target)
# CHECK-LABEL: TEST: testScalarize
# CHECK: transform.structured.scalarize
@run
-def testSplit():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
- structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
- transform.YieldOp()
+ at create_sequence
+def testSplit(target):
+ split = structured.SplitOp(target, dimension=1, split_point=42)
+ structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
# CHECK-LABEL: TEST: testSplit
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
@run
-def testTileCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
- transform.YieldOp()
+ at create_sequence
+def testTileCompact(target):
+ structured.TileOp(target, sizes=[4, 8], interchange=[0, 1])
# CHECK-LABEL: TEST: testTileCompact
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
@@ -426,15 +340,11 @@ def testTileCompact():
@run
-def testTileAttributes():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
+ at create_sequence
+def testTileAttributes(target):
attr = DenseI64ArrayAttr.get([4, 8])
ichange = DenseI64ArrayAttr.get([0, 1])
- with InsertionPoint(sequence.body):
- structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
- transform.YieldOp()
+ structured.TileOp(target, sizes=attr, interchange=ichange)
# CHECK-LABEL: TEST: testTileAttributes
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
@@ -442,15 +352,9 @@ def testTileAttributes():
@run
-def testTileZero():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
- )
- with InsertionPoint(sequence.body):
- structured.TileOp(
- sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3]
- )
- transform.YieldOp()
+ at create_sequence
+def testTileZero(target):
+ structured.TileOp(target, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
# CHECK-LABEL: TEST: testTileZero
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
@@ -480,32 +384,22 @@ def testTileDynamic():
@run
-def testTileExplicitLoopTypeSingle():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- structured.TileOp(
- transform.OperationType.get("scf.for"), sequence.bodyTarget, sizes=[2, 3, 4]
- )
- transform.YieldOp()
+ at create_sequence
+def testTileExplicitLoopTypeSingle(target):
+ structured.TileOp(transform.OperationType.get("scf.for"), target, sizes=[2, 3, 4])
# CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
# CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
# CHECK-COUNT-3: !transform.op<"scf.for">
@run
-def testTileExplicitLoopTypeAll():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
+ at create_sequence
+def testTileExplicitLoopTypeAll(target):
types = [
transform.OperationType.get(x)
for x in ["scf.for", "scf.parallel", "scf.forall"]
]
- with InsertionPoint(sequence.body):
- structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
- transform.YieldOp()
+ structured.TileOp(types, target, sizes=[2, 3, 4])
# CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
# CHECK: = transform.structured.tile
# CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
@@ -513,31 +407,22 @@ def testTileExplicitLoopTypeAll():
@run
-def testTileScalable():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ at create_sequence
+def testTileScalable(target):
+ structured.TileOp(
+ target,
+ sizes=[4, [2]],
)
- with InsertionPoint(sequence.body):
- structured.TileOp(
- sequence.bodyTarget,
- sizes=[4, [2]],
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testTileScalable
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, [2]]
@run
-def testTileToForallCompact():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate,
- [],
- transform.OperationType.get("linalg.matmul"),
- )
- with InsertionPoint(sequence.body):
- structured.TileToForallOp(sequence.bodyTarget, num_threads=[2, 3, 4])
- transform.YieldOp()
+ at create_sequence
+def testTileToForallCompact(target):
+ matmul = transform.CastOp(transform.OperationType.get("linalg.matmul"), target)
+ structured.TileToForallOp(matmul, num_threads=[2, 3, 4])
# CHECK-LABEL: TEST: testTileToForallCompact
# CHECK: = transform.structured.tile_to_forall_op
# CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
@@ -545,18 +430,14 @@ def testTileToForallCompact():
@run
-def testTileToForallLoopsAndTileOpTypes():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ at create_sequence
+def testTileToForallLoopsAndTileOpTypes(target):
+ structured.TileToForallOp(
+ transform.OperationType.get("scf.forall"), # loops_type
+ transform.OperationType.get("linalg.matmul"), # tiled_op_type
+ target,
+ num_threads=[2, 3, 4],
)
- with InsertionPoint(sequence.body):
- structured.TileToForallOp(
- transform.OperationType.get("scf.forall"), # loops_type
- transform.OperationType.get("linalg.matmul"), # tiled_op_type
- sequence.bodyTarget,
- num_threads=[2, 3, 4],
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes
# CHECK: = transform.structured.tile_to_forall_op
# CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
@@ -564,76 +445,54 @@ def testTileToForallLoopsAndTileOpTypes():
@run
-def testTileToForallTileSizes():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4])
- transform.YieldOp()
+ at create_sequence
+def testTileToForallTileSizes(target):
+ structured.TileToForallOp(target, tile_sizes=[2, 3, 4])
# CHECK-LABEL: TEST: testTileToForallTileSizes
# CHECK: = transform.structured.tile_to_forall_op
# CHECK-SAME: num_threads [] tile_sizes [2, 3, 4]
@run
-def testTileToForallMixedDynamic():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
- structured.TileToForallOp(sequence.bodyTarget, num_threads=[n, 3, 4])
- transform.YieldOp()
+ at create_sequence
+def testTileToForallMixedDynamic(target):
+ n = structured.MatchOp.match_op_names(target, ["test.dummy"])
+ structured.TileToForallOp(target, num_threads=[n, 3, 4])
# CHECK-LABEL: TEST: testTileToForallMixedDynamic
# CHECK: = transform.structured.tile_to_forall_op
# CHECK-SAME: num_threads [%{{.*}} : !transform.any_op, 3, 4]
@run
-def testTileToForallPackedDynamic():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
- structured.TileToForallOp(sequence.bodyTarget, num_threads=n)
- transform.YieldOp()
+ at create_sequence
+def testTileToForallPackedDynamic(target):
+ n = structured.MatchOp.match_op_names(target, ["test.dummy"])
+ structured.TileToForallOp(target, num_threads=n)
# CHECK-LABEL: TEST: testTileToForallPackedDynamic
# CHECK: = transform.structured.tile_to_forall_op
# CHECK-SAME: num_threads *(%0 : !transform.any_op)
@run
-def testTileToForallMapping():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
- )
- with InsertionPoint(sequence.body):
- mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
- structured.TileToForallOp(
- sequence.bodyTarget, num_threads=[2, 3], mapping=mapping
- )
- transform.YieldOp()
+ at create_sequence
+def testTileToForallMapping(target):
+ mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
+ structured.TileToForallOp(target, num_threads=[2, 3], mapping=mapping)
# CHECK-LABEL: TEST: testTileToForallMapping
# CHECK: = transform.structured.tile_to_forall_op
# CHECK-SAME: mapping = [#gpu.thread<y>, #gpu.thread<x>]
@run
-def testVectorizeAllAttrs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ at create_sequence
+def testVectorizeAllAttrs(target):
+ structured.VectorizeOp(
+ target,
+ disable_multi_reduction_to_contract_patterns=True,
+ disable_transfer_permutation_map_lowering_patterns=True,
+ vectorize_nd_extract=True,
+ vectorize_padding=True,
)
- with InsertionPoint(sequence.body):
- structured.VectorizeOp(
- sequence.bodyTarget,
- disable_multi_reduction_to_contract_patterns=True,
- disable_transfer_permutation_map_lowering_patterns=True,
- vectorize_nd_extract=True,
- vectorize_padding=True,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testVectorizeAllAttrs
# CHECK: transform.sequence
# CHECK: = transform.structured.vectorize
@@ -644,19 +503,15 @@ def testVectorizeAllAttrs():
@run
-def testVectorizeNoAttrs():
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
+ at create_sequence
+def testVectorizeNoAttrs(target):
+ structured.VectorizeOp(
+ target,
+ disable_multi_reduction_to_contract_patterns=False,
+ disable_transfer_permutation_map_lowering_patterns=False,
+ vectorize_nd_extract=False,
+ vectorize_padding=False,
)
- with InsertionPoint(sequence.body):
- structured.VectorizeOp(
- sequence.bodyTarget,
- disable_multi_reduction_to_contract_patterns=False,
- disable_transfer_permutation_map_lowering_patterns=False,
- vectorize_nd_extract=False,
- vectorize_padding=False,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testVectorizeNoAttrs
# CHECK: transform.sequence
# CHECK: = transform.structured.vectorize
@@ -667,20 +522,16 @@ def testVectorizeNoAttrs():
@run
-def testMatchInterfaceEnum():
+ at create_sequence
+def testMatchInterfaceEnum(target):
names = ArrayAttr.get([StringAttr.get("test.dummy")])
result_type = transform.AnyOpType.get()
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ fused = structured.MatchOp.__base__(
+ result_type,
+ target,
+ ops=names,
+ interface=structured.MatchInterfaceEnum.LinalgOp,
)
- with InsertionPoint(sequence.body):
- fused = structured.MatchOp.__base__(
- result_type,
- sequence.bodyTarget,
- ops=names,
- interface=structured.MatchInterfaceEnum.LinalgOp,
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMatchInterfaceEnum
# CHECK: transform.sequence
# CHECK: = transform.structured.match
@@ -688,7 +539,8 @@ def testMatchInterfaceEnum():
@run
-def testMatchInterfaceEnumReplaceAttributeBuilder():
+ at create_sequence
+def testMatchInterfaceEnumReplaceAttributeBuilder(target):
@register_attribute_builder("MatchInterfaceEnum", replace=True)
def match_interface_enum(x, context):
if x == "LinalgOp":
@@ -699,17 +551,12 @@ def match_interface_enum(x, context):
names = ArrayAttr.get([StringAttr.get("test.dummy")])
result_type = transform.AnyOpType.get()
- sequence = transform.SequenceOp(
- transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ fused = structured.MatchOp.__base__(
+ result_type,
+ target,
+ ops=names,
+ interface="TilingInterface",
)
- with InsertionPoint(sequence.body):
- fused = structured.MatchOp.__base__(
- result_type,
- sequence.bodyTarget,
- ops=names,
- interface="TilingInterface",
- )
- transform.YieldOp()
# CHECK-LABEL: TEST: testMatchInterfaceEnumReplaceAttributeBuilder
# CHECK: transform.sequence
# CHECK: = transform.structured.match
More information about the Mlir-commits
mailing list