[Mlir-commits] [mlir] bffec21 - [mlir][transform] Add ForeachOp to transform dialect
Matthias Springer
llvmlistbot at llvm.org
Tue Jul 26 09:11:05 PDT 2022
Author: Matthias Springer
Date: 2022-07-26T18:07:44+02:00
New Revision: bffec215abbd643ceac83e58caa244ded8cd837c
URL: https://github.com/llvm/llvm-project/commit/bffec215abbd643ceac83e58caa244ded8cd837c
DIFF: https://github.com/llvm/llvm-project/commit/bffec215abbd643ceac83e58caa244ded8cd837c.diff
LOG: [mlir][transform] Add ForeachOp to transform dialect
This op "unbatches" an op handle and executes the loop body for each payload op.
Differential Revision: https://reviews.llvm.org/D130257
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/Dialect/Transform/ops.mlir
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 27a68b05429c4..f11e3da6f93d3 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -64,6 +64,14 @@ def Transform_Dialect : Dialect {
correspond to groups of outer and inner loops, respectively, produced by
the tiling transformation.
+ A Transform IR value such as `%0` may be associated with multiple payload
+ operations. Conceptually, this is a set of operations, and no assumptions
+ should be made about their order. Most Transform IR ops support operand
+ values that are mapped to multiple operations; they usually apply the
+ respective transformation to every mapped op ("batched execution").
+ Deviations from this convention are described in the documentation of the
+ respective Transform IR ops.
+
Overall, Transform IR ops are expected to be contained in a single top-level
op. Such top-level ops specify how to apply the transformations described
by the operations they contain, e.g., `transform.sequence` executes
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index d578c15e48370..bc5fd01d215ae 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -95,6 +95,46 @@ def AlternativesOp : TransformDialectOp<"alternatives",
let hasVerifier = 1;
}
+// Loop op that "unbatches" a handle: its body runs once per payload op
+// associated with the target handle.
+def ForeachOp : TransformDialectOp<"foreach",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+         "getSuccessorRegions", "getSuccessorEntryOperands"]>,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
+    ]> {
+  let summary = "Executes the body for each payload op";
+  let description = [{
+    This op has exactly one region with exactly one block ("body"). The body is
+    executed for each payload op that is associated with the target operand in
+    an unbatched fashion. I.e., the block argument ("iteration variable") is
+    always mapped to exactly one payload op.
+
+    This op always reads the target handle. Furthermore, it consumes the handle
+    if there is a transform op in the body that consumes the iteration variable.
+    This op does not return anything.
+
+    The transformations inside the body are applied in order of their
+    appearance. During application, if any transformation in the body fails,
+    this op fails immediately and no further iterations are executed. The
+    payload IR is left in a potentially invalid state, i.e., this operation
+    offers no transformation rollback capabilities.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = "$target $body attr-dict";
+
+  let extraClassDeclaration = [{
+    /// Allow the dialect prefix to be omitted.
+    static StringRef getDefaultDialect() { return "transform"; }
+
+    /// Returns argument #0 of the body block (the "iteration variable"), which
+    /// is mapped to a single payload op on each iteration. The accessor itself
+    /// does not check that the block has an argument.
+    BlockArgument getIterationVariable() {
+      return getBody().front().getArgument(0);
+    }
+  }];
+}
+
def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index bc2d9710b8a41..5d702d9f3dbdd 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -273,6 +273,64 @@ LogicalResult transform::AlternativesOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ForeachOp
+//===----------------------------------------------------------------------===//
+
+/// Runs the body once for every payload op associated with the target handle.
+/// On each iteration a fresh region scope is created for the body, the
+/// iteration variable is mapped to the current payload op alone, and the
+/// body's non-terminator ops are applied in order. The first non-successful
+/// result produced by a body op is returned unchanged, ending the loop.
+DiagnosedSilenceableFailure
+transform::ForeachOp::apply(transform::TransformResults &results,
+                            transform::TransformState &state) {
+  auto &bodyBlock = getBody().front();
+  for (Operation *payloadOp : state.getPayloadOps(getTarget())) {
+    // Recreated every iteration; presumably releases the previous
+    // iteration's body mappings when destroyed.
+    auto iterationScope = state.make_region_scope(getBody());
+    if (failed(state.mapBlockArguments(getIterationVariable(), {payloadOp})))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    for (Operation &bodyOp : bodyBlock.without_terminator()) {
+      auto transformOp = cast<transform::TransformOpInterface>(bodyOp);
+      DiagnosedSilenceableFailure status = state.applyTransform(transformOp);
+      if (!status.succeeded())
+        return status;
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Reports the memory effects on the target handle: it is consumed when any
+/// non-terminator op in the body consumes the iteration variable, and only
+/// read otherwise.
+void transform::ForeachOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  BlockArgument iterationVar = getIterationVariable();
+  bool bodyConsumesIterationVar = false;
+  for (Operation &bodyOp : getBody().front().without_terminator()) {
+    if (isHandleConsumed(iterationVar, cast<TransformOpInterface>(&bodyOp))) {
+      bodyConsumesIterationVar = true;
+      break;
+    }
+  }
+
+  if (!bodyConsumesIterationVar) {
+    onlyReadsHandle(getTarget(), effects);
+    return;
+  }
+  consumesHandle(getTarget(), effects);
+}
+
+/// RegionBranchOpInterface: describes the loop's control flow.
+///  - From the parent op (`index` unset), control enters the body and the
+///    body block's arguments receive the forwarded entry operands.
+///  - From the end of the body (region #0), control either loops back into
+///    the body for the next iteration or returns to the parent op.
+/// NOTE(review): when the target has no payload ops, `apply` never enters the
+/// body, yet the parent is not listed as a successor of the entry edge.
+void transform::ForeachOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  Region *bodyRegion = &getBody();
+  if (!index) {
+    regions.emplace_back(bodyRegion, bodyRegion->getArguments());
+    return;
+  }
+
+  // Branch back to the region or the parent.
+  assert(*index == 0 && "unexpected region index");
+  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
+  regions.emplace_back();
+}
+
+/// Returns the operands forwarded from the parent op into the body on entry,
+/// i.e., the op's single `target` operand, which feeds the iteration variable.
+OperandRange
+transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
+  assert(index && *index == 0 && "unexpected region index");
+  // At application time the iteration variable is mapped to only one of the
+  // target's payload ops per iteration, i.e., a one-element subset of them.
+  return (*this)->getOperands();
+}
+
//===----------------------------------------------------------------------===//
// GetClosestIsolatedParentOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index b76bd07d3a475..2650651d25b33 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -184,3 +184,18 @@ transform.alternatives {
^bb0:
transform.yield
}
+
+// -----
+
+// %0 has two potential consumers: the foreach op, which consumes its target
+// because its body consumes the iteration variable, and the trailing
+// test_consume_operand. This must be diagnosed as an error.
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+ // expected-error @below {{result #0 has more than one potential consumer}}
+ %0 = test_produce_param_or_forward_operand 42
+ // expected-note @below {{used here as operand #0}}
+ transform.foreach %0 {
+ ^bb1(%arg1: !pdl.operation):
+ transform.test_consume_operand %arg1
+ }
+ // expected-note @below {{used here as operand #0}}
+ transform.test_consume_operand %0
+}
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index e9e99de310b72..23dd6b835dd8f 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -49,3 +49,12 @@ transform.sequence {
^bb3(%arg3: !pdl.operation):
}
}
+
+// A foreach op whose body has only the iteration variable and the implicit
+// terminator must be accepted and printed back as a foreach.
+// CHECK: transform.sequence
+// CHECK: foreach
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+ transform.foreach %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 1f0dd40fc8b8b..ed3a2507e2a1c 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -597,3 +597,33 @@ module {
}
}
}
+
+// -----
+
+// Payload IR: two arith.constant ops, both matched by the @const pattern
+// below; the foreach body prints a "transform applied" remark at each one.
+func.func @bar() {
+ // expected-remark @below {{transform applied}}
+ %0 = arith.constant 0 : i32
+ // expected-remark @below {{transform applied}}
+ %1 = arith.constant 1 : i32
+ return
+}
+
+// Collects every arith.constant into one handle (%f) and iterates over it with
+// transform.foreach: inside the body, %arg2 refers to a single payload op.
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @const : benefit(1) {
+ %r = pdl.types
+ %0 = pdl.operation "arith.constant" -> (%r : !pdl.range<type>)
+ pdl.rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %f = pdl_match @const in %arg1
+ transform.foreach %f {
+ ^bb2(%arg2: !pdl.operation):
+ // The body runs once per matched op, so each print below sees one op.
+ // expected-remark @below {{1}}
+ transform.test_print_number_of_associated_payload_ir_ops %arg2
+ transform.test_print_remark_at_operand %arg2, "transform applied"
+ }
+ }
+}
More information about the Mlir-commits
mailing list