[Mlir-commits] [mlir] c439913 - [MLIR][Transform] Add attribute in MatchOp to filter by operand type (#67994)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 7 00:28:56 PST 2023
Author: Pablo Antonio Martinez
Date: 2023-12-07T08:28:52Z
New Revision: c4399130ae403acf4e6325b8b46a51bb6abf222f
URL: https://github.com/llvm/llvm-project/commit/c4399130ae403acf4e6325b8b46a51bb6abf222f
DIFF: https://github.com/llvm/llvm-project/commit/c4399130ae403acf4e6325b8b46a51bb6abf222f.diff
LOG: [MLIR][Transform] Add attribute in MatchOp to filter by operand type (#67994)
This patch adds the `filter_operand_types` attribute to transform::MatchOp, allowing ops to be filtered by their operand types.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-match.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index de65f3176c46a..77ed9db5e71bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -574,6 +574,11 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
- attribute: the matched op must have all specified attributes (with their
specified values).
- filter_result_type: the matched op must return exactly this one type.
+ - filter_operand_types: all the operands of the matched op must be of
+ this type. If more than one type is specified, then the length of the list
+ must be equal to the number of operands in the matched op, and the match
+ will succeed only if the operand types match all the types in the list
+ in the order in which they are specified.
Note: Only ops that satisfy all specified constraints are matched.
@@ -595,7 +600,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
OptionalAttr<StrArrayAttr>:$ops,
OptionalAttr<MatchInterfaceEnum>:$interface,
OptionalAttr<DictionaryAttr>:$op_attrs,
- OptionalAttr<TypeAttr>:$filter_result_type);
+ OptionalAttr<TypeAttr>:$filter_result_type,
+ OptionalAttr<TypeArrayAttr>:$filter_operand_types);
// TODO: variadic results when needed.
let results = (outs TransformHandleTypeInterface:$results);
@@ -609,6 +615,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
(`interface` `{` $interface^ `}`)?
(`attributes` $op_attrs^)?
(`filter_result_type` `=` $filter_result_type^)?
+ (`filter_operand_types` `=` $filter_operand_types^)?
`in` $target attr-dict
`:` functional-type($target, results)
}];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e3713457e8412..54055aefbc512 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1171,6 +1171,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
}
SmallVector<Operation *> res;
+ bool incorrectNumOperandTypes = false;
auto matchFun = [&](Operation *op) {
if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
return;
@@ -1210,12 +1211,47 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
return;
}
+ if (getFilterOperandTypes().has_value()) {
+ mlir::ArrayAttr types = getFilterOperandTypes().value();
+ auto operandTypes = op->getOperandTypes();
+
+ if (types.size() == 1) {
+ // All the operands must must be equal to the specified type
+ auto typeattr =
+ dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
+ Type t = typeattr.getValue().cast<::mlir::Type>();
+ if (!llvm::all_of(op->getOperandTypes(),
+ [&](Type operandType) { return operandType == t; }))
+ return;
+ } else {
+ // The operand types must match all the types in the list (in the same
+ // order in with they are specified)
+ if (types.size() != operandTypes.size()) {
+ incorrectNumOperandTypes = true;
+ return;
+ }
+
+ for (auto [attr, operandType] :
+ llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
+ auto typeattr = cast<mlir::TypeAttr>(attr);
+ Type type = typeattr.getValue().cast<::mlir::Type>();
+
+ if (type != operandType)
+ return;
+ }
+ }
+ }
+
// All constraints are satisfied.
res.push_back(op);
return;
};
(*payloadOps.begin())->walk(matchFun);
+ if (incorrectNumOperandTypes)
+ return emitDefiniteFailure("If filter_operand_types contains more than a "
+ "type, then it must contain as much types as "
+ "the number of operands in the target ops");
results.set(cast<OpResult>(getResult()), res);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 7d48b1f403b3b..3b30c18457643 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -43,6 +43,46 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @by_operand_type() {
+ %c2 = arith.constant 2.0: f32
+ %v = arith.constant 8: i32
+ %r1 = math.fpowi %c2, %v : f32, i32
+ // expected-remark @below {{matched op name}}
+ %r2 = arith.addf %c2, %c2 : f32
+ // expected-remark @below {{matched op name}}
+ %r3 = arith.fptoui %r2 : f32 to i32
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %match_name1 = transform.structured.match
+ ops{["arith.fptoui"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %match_name1, "matched op name" : !transform.any_op
+ transform.test_consume_operand %match_name1 : !transform.any_op
+
+ %match_name2 = transform.structured.match
+ ops{["arith.addf"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %match_name2, "matched op name" : !transform.any_op
+ transform.test_consume_operand %match_name2 : !transform.any_op
+
+ %no_match_name1 = transform.structured.match
+ ops{["arith.fptoui"]} filter_operand_types = [i32] in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %no_match_name1, "should not match" : !transform.any_op
+ transform.test_consume_operand %no_match_name1 : !transform.any_op
+
+ %no_match_name2 = transform.structured.match
+ ops{["math.fpowi"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %no_match_name2, "should not match" : !transform.any_op
+ transform.test_consume_operand %no_match_name2 : !transform.any_op
+
+ // expected-error @+1 {{If filter_operand_types contains more than a type, then it must contain as much types as the number of operands in the target ops}}
+ %failure_match = transform.structured.match
+ ops{["arith.fptoui"]} filter_operand_types = [i32, i32] in %arg1 : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
%c0 = arith.constant 0.0 : f32
// expected-remark @below {{tileable}}
More information about the Mlir-commits
mailing list