[Mlir-commits] [mlir] [mlir][Transform] Extend transform.foreach to take multiple arguments (PR #93705)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Thu May 30 00:59:48 PDT 2024


================
@@ -1391,15 +1391,62 @@ DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
-  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-  // Store payload ops in a vector because ops may be removed from the mapping
-  // by the TrackingRewriter while the iteration is in progress.
-  SmallVector<Operation *> targets =
-      llvm::to_vector(state.getPayloadOps(getTarget()));
-  for (Operation *op : targets) {
+  // Collect the arguments with which to call each iteration of the body.
+  // We store the payload before executing the body as ops may be removed from
+  // the mapping by the TrackingRewriter while the iteration is in progress.
+  SmallVector<SmallVector<MappedValue>> zippedArgs;
+  for (auto firstTarget : getTargets().take_front(1)) // Loop runs at most once.
+    // For each element, init a tuple with which to call the body later on.
+    if (isa<TransformHandleTypeInterface>(firstTarget.getType()))
+      for (auto &op : state.getPayloadOps(firstTarget))
+        zippedArgs.append({{op}}); // NB: append's argument is an init-list.
+    else if (isa<TransformValueHandleTypeInterface>(firstTarget.getType()))
+      for (auto val : state.getPayloadValues(firstTarget))
+        zippedArgs.append({{val}});
+    else if (isa<TransformParamTypeInterface>(firstTarget.getType()))
+      for (auto param : state.getParams(firstTarget))
+        zippedArgs.append({{param}});
+    else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << firstTarget.getType();
+
+  for (auto target : getTargets().drop_front(1)) {
+    // Append each element of payload to the co-indexed body-arguments-as-tuple.
+    size_t payloadSize = 0;
+    if (isa<TransformHandleTypeInterface>(target.getType())) {
+      for (auto op : state.getPayloadOps(target))
+        if (++payloadSize <= zippedArgs.size())
----------------
ftynse wrote:

Please document what's going on here with `payloadSize`. It keeps track of how many payload elements are associated with the target. It is later used to check that the payloads of all targets have the same number of elements, without traversing the corresponding list twice.
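The following minimal C++ sketch illustrates the pattern the comment asks to document. It assumes that payload elements can be modeled as plain `int`s and that each target's payload is a `std::vector`. `zipPayloads`, `Payload`, and `Tuple` are hypothetical names, not part of the MLIR API.

The sketch shows the counter doing double duty:

- inside the loop, it guards the co-indexed append;
- after the loop, it detects a size mismatch without a second traversal.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for one payload element (an op, value or param).
using Payload = int;
// Hypothetical stand-in for the argument tuple of one body iteration.
using Tuple = std::vector<Payload>;

// Zips the payloads of several targets into per-iteration argument tuples.
// Returns std::nullopt if the targets do not all have the same number of
// payload elements. Each target's payload list is traversed exactly once.
std::optional<std::vector<Tuple>>
zipPayloads(const std::vector<std::vector<Payload>> &targets) {
  std::vector<Tuple> zippedArgs;
  if (targets.empty())
    return zippedArgs;

  // The first target fixes the number of iterations: one tuple per element.
  for (const Payload &p : targets.front())
    zippedArgs.push_back(Tuple{p});

  for (std::size_t t = 1; t < targets.size(); ++t) {
    // Counts the elements of this target's payload while appending each one
    // to the co-indexed tuple. Elements past the end are counted but not
    // appended, so the bounds stay valid and no second pass is needed.
    std::size_t payloadSize = 0;
    for (const Payload &p : targets[t]) {
      if (++payloadSize <= zippedArgs.size()) {
        // Every tuple holds exactly one element per previously zipped target.
        assert(zippedArgs[payloadSize - 1].size() == t);
        zippedArgs[payloadSize - 1].push_back(p);
      }
    }
    // Too many or too few elements for this target is a mismatch.
    if (payloadSize != zippedArgs.size())
      return std::nullopt;
  }
  return zippedArgs;
}
```

For example, zipping `{{1, 2, 3}, {4, 5, 6}}` yields the tuples `{1, 4}`, `{2, 5}` and `{3, 6}`. Zipping `{{1, 2}, {3}}` returns `std::nullopt`. In the op itself, a mismatch would presumably surface as a diagnostic rather than as an empty optional.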

https://github.com/llvm/llvm-project/pull/93705

