[Mlir-commits] [mlir] [mlir][Transform] Extend transform.foreach to take multiple arguments (PR #93705)
Rolf Morel
llvmlistbot at llvm.org
Thu May 30 09:38:20 PDT 2024
================
@@ -1391,15 +1391,62 @@ DiagnosedSilenceableFailure
transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
- // Store payload ops in a vector because ops may be removed from the mapping
- // by the TrackingRewriter while the iteration is in progress.
- SmallVector<Operation *> targets =
- llvm::to_vector(state.getPayloadOps(getTarget()));
- for (Operation *op : targets) {
+ // Collect the arguments with which to call each iteration of the body.
+ // We store the payload before executing the body as ops may be removed from
+ // the mapping by the TrackingRewriter while the iteration is in progress.
+ SmallVector<SmallVector<MappedValue>> zippedArgs;
+ for (auto firstTarget : getTargets().take_front(1)) // Loop runs at most once.
+ // For each element, init a tuple with which to call the body later on.
+ if (isa<TransformHandleTypeInterface>(firstTarget.getType()))
+ for (auto &op : state.getPayloadOps(firstTarget))
+ zippedArgs.append({{op}}); // NB: append's argument is an init-list.
+ else if (isa<TransformValueHandleTypeInterface>(firstTarget.getType()))
+ for (auto val : state.getPayloadValues(firstTarget))
+ zippedArgs.append({{val}});
+ else if (isa<TransformParamTypeInterface>(firstTarget.getType()))
+ for (auto param : state.getParams(firstTarget))
+ zippedArgs.append({{param}});
+ else
+ return emitDefiniteFailure()
+ << "unhandled handle type " << firstTarget.getType();
+
+ for (auto target : getTargets().drop_front(1)) {
+ // Append each element of payload to the co-indexed body-arguments-as-tuple.
+ size_t payloadSize = 0;
+ if (isa<TransformHandleTypeInterface>(target.getType())) {
+ for (auto op : state.getPayloadOps(target))
+ if (++payloadSize <= zippedArgs.size())
+ zippedArgs[payloadSize - 1].append({op});
+ } else if (isa<TransformValueHandleTypeInterface>(target.getType())) {
+ for (auto val : state.getPayloadValues(target))
+ if (++payloadSize <= zippedArgs.size())
+ zippedArgs[payloadSize - 1].append({val});
+ } else if (isa<TransformParamTypeInterface>(target.getType())) {
+ for (auto param : state.getParams(target))
+ if (++payloadSize <= zippedArgs.size())
+ zippedArgs[payloadSize - 1].append({param});
+ } else
+ return emitDefiniteFailure()
+ << "unhandled handle type " << target.getType();
+
+ if (payloadSize != zippedArgs.size())
+ return emitSilenceableError()
+ << "payload size of prior targets (" << zippedArgs.size()
+ << ") differs from payload size (" << payloadSize << ") of target "
+ << target;
+ }
+
+ // For each arguments-as-tuple collected up above, execute the body region.
+ SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
+ for (SmallVector<MappedValue> &argsTuple : zippedArgs) {
auto scope = state.make_region_scope(getBody());
- if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
- return DiagnosedSilenceableFailure::definiteFailure();
+ // Set up arguments to the region's block.
+ for (auto &&[blockArg, argument] :
+ llvm::zip_equal(getBody().front().getArguments(), argsTuple))
+ // Note: each blockArg's handle gets associated with just a single element
+ // from the corresponding target's payload.
+ if (failed(state.mapBlockArgument(blockArg, {argument})))
----------------
rolfmorel wrote:
Good call — I have adopted this. The only disadvantage seems to be the transposed accesses, but the code is much simpler now.
https://github.com/llvm/llvm-project/pull/93705
More information about the Mlir-commits mailing list