[Mlir-commits] [mlir] 0581ab6 - [mlir][linalg][transform] Support matching of attributes (and their values)
Matthias Springer
llvmlistbot at llvm.org
Fri Aug 12 05:55:09 PDT 2022
Author: Matthias Springer
Date: 2022-08-12T14:55:00+02:00
New Revision: 0581ab65ea049069faaa103226948f033fb3fda6
URL: https://github.com/llvm/llvm-project/commit/0581ab65ea049069faaa103226948f033fb3fda6
DIFF: https://github.com/llvm/llvm-project/commit/0581ab65ea049069faaa103226948f033fb3fda6.diff
LOG: [mlir][linalg][transform] Support matching of attributes (and their values)
Do not just check if an attribute exists on the payload op. Also check its value.
Differential Revision: https://reviews.llvm.org/D131760
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-match.mlir
mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 0ad37ada457d7..11c9b898ef58d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -197,21 +197,21 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
- ops: an optional StrArrayAttr specifying the concrete name of an op.
Multiple names can be specified. Matched ops must have one of specified
names.
- - attribute: an optional Str specifying the name of an attribute that
- matched ops must have.
+ - attribute: the matched op must have all specified attributes (with their
+ specified values).
Note: Only ops that satisfy all specified constraints are matched.
TODO: Extend with regions to allow a limited form of constraints.
#### Return modes
-
+
This op traverses the ops nested under `target` and returns the handles to
all the operations that match the requirements.
This op fails if the target is not a handle to exactly one operation.
Otherwise it succeeds.
-
+
This operation does not consume the target handle and produces new handles:
it is a navigation op.
}];
@@ -219,14 +219,14 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
let arguments = (ins PDL_Operation:$target,
OptionalAttr<StrArrayAttr>:$ops,
OptionalAttr<MatchInterfaceEnum>:$interface,
- OptionalAttr<StrAttr>:$attribute);
+ OptionalAttr<DictionaryAttr>:$op_attrs);
// TODO: variadic results when needed.
let results = (outs PDL_Operation:$results);
let assemblyFormat = [{
- (`ops` `{` $ops^ `}`)?
- (`interface` `{` $interface^ `}`)?
- (`attribute` `{` $attribute^ `}`)?
+ (`ops` `{` $ops^ `}`)?
+ (`interface` `{` $interface^ `}`)?
+ (`attributes` $op_attrs^)?
`in` $target attr-dict
}];
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8f2fb306030bd..328c4141e50fa 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -467,8 +467,19 @@ transform::MatchOp::apply(transform::TransformResults &results,
return WalkResult::advance();
}
- if (getAttribute().has_value() && !op->hasAttr(getAttribute().value()))
- return WalkResult::advance();
+ // Check if all specified attributes match.
+ if (getOpAttrs().has_value()) {
+ DictionaryAttr opAttrs = getOpAttrs().value();
+ for (NamedAttribute attr : opAttrs) {
+ if (attr.getName() == getInterfaceAttrName() ||
+ attr.getName() == getOpsAttrName())
+ continue;
+ if (!op->hasAttr(attr.getName()))
+ return WalkResult::advance();
+ if (op->getAttr(attr.getName()) != attr.getValue())
+ return WalkResult::advance();
+ }
+ }
// All constraints are satisfied.
res.push_back(op);
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 06af6c8dad96a..002f142f55d1b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -17,8 +17,45 @@ transform.with_pdl_patterns {
transform.test_print_remark_at_operand %match_name, "matched op name"
transform.test_consume_operand %match_name
- %match_attr = transform.structured.match ops{["arith.constant"]} attribute{"my_attr"} in %arg1
+ %match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1
transform.test_print_remark_at_operand %match_attr, "matched attr name"
transform.test_consume_operand %match_attr
}
}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+func.func @match_complex_attribute(%arg0: tensor<12x128x32xf32>)
+ -> tensor<128x12x32xf32> {
+ %0 = linalg.init_tensor [128, 12, 32] : tensor<128x12x32xf32>
+ // expected-remark @below {{matched complex attr}}
+ %1 = linalg.generic {indexing_maps = [#map0, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<12x128x32xf32>)
+ outs(%0 : tensor<128x12x32xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ linalg.yield %arg1 : f32
+ } -> tensor<128x12x32xf32>
+ return %1 : tensor<128x12x32xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %match_attr = transform.structured.match
+ ops{["linalg.generic"]}
+ attributes{iterator_types = ["parallel", "parallel", "parallel"]}
+ in %arg1
+ transform.test_print_remark_at_operand %match_attr, "matched complex attr"
+ transform.test_consume_operand %match_attr
+
+ %no_match = transform.structured.match
+ attributes{iterator_types = ["parallel", "parallel", "reduction"]}
+ in %arg1
+ // expected-remark @below {{0}}
+ transform.test_print_number_of_associated_payload_ir_ops %no_match
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
index 1109950916ed8..1f6185f664661 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
@@ -45,8 +45,8 @@ module {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
// Find the root and all producers.
- %root = transform.structured.match attribute{"__root__"} in %arg1
- %producers = transform.structured.match attribute{"__producer__"} in %arg1
+ %root = transform.structured.match attributes{"__root__"} in %arg1
+ %producers = transform.structured.match attributes{"__producer__"} in %arg1
// Tile the root.
%foreach_thread_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %root num_threads [10, 20]
@@ -105,8 +105,8 @@ module {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
// Find the root and all producers.
- %root = transform.structured.match attribute{"__root__"} in %arg1
- %producers = transform.structured.match attribute{"__producer__"} in %arg1
+ %root = transform.structured.match attributes{"__root__"} in %arg1
+ %producers = transform.structured.match attributes{"__producer__"} in %arg1
%reversed_producers = transform.test_reverse_payload_ops %producers
// Tile the root.
More information about the Mlir-commits mailing list