[Mlir-commits] [mlir] 488f88b - [mlir][transform] Add elementwise criteria to `match.structured.body` (#79626)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 31 01:12:38 PST 2024
Author: srcarroll
Date: 2024-01-31T10:12:33+01:00
New Revision: 488f88b844739fb8dac6a05799a1e1ec450c0ad9
URL: https://github.com/llvm/llvm-project/commit/488f88b844739fb8dac6a05799a1e1ec450c0ad9
DIFF: https://github.com/llvm/llvm-project/commit/488f88b844739fb8dac6a05799a1e1ec450c0ad9.diff
LOG: [mlir][transform] Add elementwise criteria to `match.structured.body` (#79626)
As far as I am aware, there is no simple way to match on elementwise
ops. I propose to add an `elementwise` criterion to the
`match.structured.body` op. My only hesitation is that whether an op is
elementwise is determined not only by its body but also by its indexing
maps, so if others find this too awkward, I can implement a separate
match op instead.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
mlir/test/Dialect/Linalg/match-ops-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index 162dd05f93030..dfeb8ae5d5ddb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -106,6 +106,9 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
* `passthrough`: the body of the structured payload op only forwards
inputs to the outputs (copy or broadcast).
+ * `elementwise`: the body of the structured payload op represents an
+ elementwise operation.
+
* `contraction`: the body of the structured payload op is a contraction
of the form `<red>(<elem>(bbarg0, bbarg1), bbarg2)` where `<elem>` and
`<red>` are binary operations whose names are specified in the attribute
@@ -123,6 +126,7 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
let arguments = (ins TransformHandleTypeInterface:$operand_handle,
OptionalAttr<I64Attr>:$reduction_position,
UnitAttr:$passthrough,
+ UnitAttr:$elementwise,
OptionalAttr<StrArrayAttr>:$contraction);
let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)";
let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 115da4b90e063..fb18886c16b16 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/FunctionImplementation.h"
@@ -187,6 +188,11 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
}
return DiagnosedSilenceableFailure::success();
}
+ if (getElementwise()) {
+ if (!isElementwise(linalgOp))
+ return emitSilenceableError() << "not elementwise";
+ return DiagnosedSilenceableFailure::success();
+ }
if (std::optional<ArrayAttr> contractionOps = getContraction()) {
Block &body = linalgOp->getRegion(0).front();
std::string message;
@@ -209,13 +215,14 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
LogicalResult transform::MatchStructuredBodyOp::verify() {
int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
- getContraction().has_value();
+ getElementwise() + getContraction().has_value();
if (numOptions > 1) {
std::string attributeNames;
llvm::raw_string_ostream os(attributeNames);
llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
getPassthroughAttrName(),
+ getElementwiseAttrName(),
getContractionAttrName()},
os);
return emitOpError() << "only one of {" << os.str() << "} is allowed";
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index a7353a4c38881..24c7bdd9e1050 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -180,6 +180,63 @@ module attributes { transform.with_named_sequence } {
// -----
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @print_elementwise(%arg0: !transform.any_op {transform.readonly}) {
+ transform.debug.emit_remark_at %arg0, "elementwise" : !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @match_structured_body_elementwise(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.body %arg1 { elementwise } : !transform.any_op
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %0 : !transform.any_op
+ }
+
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
+ transform.foreach_match in %arg0
+ @match_structured_body_elementwise -> @print_elementwise
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
+ %cst0 = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // expected-remark @below {{elementwise}}
+ %fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32>
+ // expected-remark @below {{elementwise}}
+ %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>)
+ %non_elementwise = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
+ %0 = arith.addf %arg0, %arg1 : f32
+ %1 = tensor.dim %add, %c0 : tensor<2xf32>
+ %2 = arith.subi %1, %c1 : index
+ %3 = tensor.extract %add[%2] : tensor<2xf32>
+ %4 = arith.mulf %0, %3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<2x3xf32>
+ // expected-remark @below {{elementwise}}
+ %add_bcast = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
+ %0 = arith.addf %arg0, %arg1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<2x3xf32>
+ return %add, %add_bcast, %non_elementwise : tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>
+ }
+}
+
+// -----
+
module attributes { transform.with_named_sequence } {
transform.named_sequence @print_reduction(%arg0: !transform.any_op {transform.readonly}) {
transform.debug.emit_remark_at %arg0, "reduction" : !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
index ec99e205090c4..9ff430a350360 100644
--- a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
@@ -64,7 +64,7 @@ transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
transform.match.structured %arg0 : !transform.any_op {
^bb1(%arg1: !transform.any_op):
- // expected-error @below {{only one of {"reduction_position", "passthrough", "contraction"} is allowed}}
+ // expected-error @below {{only one of {"reduction_position", "passthrough", "elementwise", "contraction"} is allowed}}
transform.match.structured.body %arg1 { passthrough, reduction_position = 0 } : !transform.any_op
transform.match.structured.yield
}
More information about the Mlir-commits mailing list