[Mlir-commits] [mlir] c1e6caa - [mlir][transform] Support results on ForeachOp
Matthias Springer
llvmlistbot at llvm.org
Thu Jul 28 01:43:27 PDT 2022
Author: Matthias Springer
Date: 2022-07-28T10:39:54+02:00
New Revision: c1e6caac7059ee4dd4a3860f7a1c954f07f6120e
URL: https://github.com/llvm/llvm-project/commit/c1e6caac7059ee4dd4a3860f7a1c954f07f6120e
DIFF: https://github.com/llvm/llvm-project/commit/c1e6caac7059ee4dd4a3860f7a1c954f07f6120e.diff
LOG: [mlir][transform] Support results on ForeachOp
Handles can be yielded from the ForeachOp.
Differential Revision: https://reviews.llvm.org/D130640
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index bc5fd01d215ae..1f28b33aa7b93 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -118,12 +118,17 @@ def ForeachOp : TransformDialectOp<"foreach",
the entire sequence fails immediately leaving the payload IR in potentially
invalid state, i.e., this operation offers no transformation rollback
capabilities.
+
+ This op generates as many handles as the terminating YieldOp has operands.
+ For each result, the payload ops of the corresponding YieldOp operand are
+ merged and mapped to the same resulting handle.
}];
let arguments = (ins PDL_Operation:$target);
- let results = (outs);
+ let results = (outs Variadic<PDL_Operation>:$results);
let regions = (region SizedRegion<1>:$body);
- let assemblyFormat = "$target $body attr-dict";
+ let assemblyFormat = "$target (`->` type($results)^)? $body attr-dict";
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Allow the dialect prefix to be omitted.
@@ -132,6 +137,8 @@ def ForeachOp : TransformDialectOp<"foreach",
BlockArgument getIterationVariable() {
return getBody().front().getArgument(0);
}
+
+ transform::YieldOp getYieldOp();
}];
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 5d702d9f3dbdd..511b8544ddfdc 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -281,18 +281,32 @@ DiagnosedSilenceableFailure
transform::ForeachOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+ SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {}); // One payload-op list per op result, accumulated across all iterations.
+
for (Operation *op : payloadOps) {
auto scope = state.make_region_scope(getBody());
if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
return DiagnosedSilenceableFailure::definiteFailure();
+ // Execute the loop body (all ops except the terminator); return the first non-success result as-is.
for (Operation &transform : getBody().front().without_terminator()) {
DiagnosedSilenceableFailure result = state.applyTransform(
cast<transform::TransformOpInterface>(transform));
if (!result.succeeded())
return result;
}
+
+ // Collect this iteration's yielded payload ops now, while `scope` is live (presumably the body's handle mappings do not outlive it); no-op when the op has no results.
+ for (unsigned i = 0; i < getNumResults(); ++i) {
+ ArrayRef<Operation *> yieldedOps =
+ state.getPayloadOps(getYieldOp().getOperand(i));
+ resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
+ }
}
+ // Map each result handle to the concatenation of the ops yielded by every iteration, in iteration order.
+ for (unsigned i = 0; i < getNumResults(); ++i)
+ results.set(getResult(i).cast<OpResult>(), resultOps[i]);
+
return DiagnosedSilenceableFailure::success();
}
@@ -306,6 +320,9 @@ void transform::ForeachOp::getEffects(
} else {
onlyReadsHandle(getTarget(), effects);
}
+
+ for (Value result : getResults())
+ producesHandle(result, effects);
}
void transform::ForeachOp::getSuccessorRegions(
@@ -331,6 +348,21 @@ transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
return getOperation()->getOperands();
}
+transform::YieldOp transform::ForeachOp::getYieldOp() {
+ return cast<transform::YieldOp>(getBody().front().getTerminator()); // Terminator of the body's (single) block; assumes it is a transform.yield.
+}
+/// Verifies that the body's terminator has exactly one operand per op result and that every terminator operand is a !pdl.operation handle.
+LogicalResult transform::ForeachOp::verify() {
+ auto yieldOp = getYieldOp();
+ if (getNumResults() != yieldOp.getNumOperands())
+ return emitOpError() << "expects the same number of results as the "
+ "terminator has operands";
+ for (Value v : yieldOp.getOperands())
+ if (!v.getType().isa<pdl::OperationType>())
+ return yieldOp->emitOpError("expects only PDL_Operation operands");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// GetClosestIsolatedParentOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index ed3a2507e2a1c..425ba608adeb4 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -627,3 +627,52 @@ transform.with_pdl_patterns {
}
}
}
+
+// -----
+// Check that transform.foreach results merge the payload ops yielded by all iterations.
+func.func @bar() {
+ scf.execute_region {
+ // expected-remark @below {{transform applied}}
+ %0 = arith.constant 0 : i32
+ scf.yield
+ }
+
+ scf.execute_region {
+ // expected-remark @below {{transform applied}}
+ %1 = arith.constant 1 : i32
+ // expected-remark @below {{transform applied}}
+ %2 = arith.constant 2 : i32
+ scf.yield
+ }
+
+ return
+}
+// One foreach iteration per scf.execute_region; each iteration yields the arith.constant ops matched inside it.
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @const : benefit(1) {
+ %r = pdl.types
+ %0 = pdl.operation "arith.constant" -> (%r : !pdl.range<type>)
+ pdl.rewrite %0 with "transform.dialect"
+ }
+
+ pdl.pattern @execute_region : benefit(1) {
+ %r = pdl.types
+ %0 = pdl.operation "scf.execute_region" -> (%r : !pdl.range<type>)
+ pdl.rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %f = pdl_match @execute_region in %arg1
+ %results = transform.foreach %f -> !pdl.operation {
+ ^bb2(%arg2: !pdl.operation):
+ %g = transform.pdl_match @const in %arg2
+ transform.yield %g : !pdl.operation
+ }
+ // The two iterations yield 1 and 2 constants, so %results is associated with 3 payload ops.
+ // expected-remark @below {{3}}
+ transform.test_print_number_of_associated_payload_ir_ops %results
+ transform.test_print_remark_at_operand %results, "transform applied"
+ }
+}
More information about the Mlir-commits
mailing list