[Mlir-commits] [mlir] [mlir][Transform] Extend transform.foreach to take multiple arguments (PR #93705)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Thu May 30 00:59:49 PDT 2024


================
@@ -1391,15 +1391,62 @@ DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
-  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-  // Store payload ops in a vector because ops may be removed from the mapping
-  // by the TrackingRewriter while the iteration is in progress.
-  SmallVector<Operation *> targets =
-      llvm::to_vector(state.getPayloadOps(getTarget()));
-  for (Operation *op : targets) {
+  // Collect the arguments with which to call each iteration of the body.
+  // We store the payload before executing the body as ops may be removed from
+  // the mapping by the TrackingRewriter while the iteration is in progress.
+  SmallVector<SmallVector<MappedValue>> zippedArgs;
+  for (auto firstTarget : getTargets().take_front(1)) // Loop runs at most once.
+    // For each element, init a tuple with which to call the body later on.
+    if (isa<TransformHandleTypeInterface>(firstTarget.getType()))
+      for (auto &op : state.getPayloadOps(firstTarget))
+        zippedArgs.append({{op}}); // NB: append's argument is an init-list.
+    else if (isa<TransformValueHandleTypeInterface>(firstTarget.getType()))
+      for (auto val : state.getPayloadValues(firstTarget))
+        zippedArgs.append({{val}});
+    else if (isa<TransformParamTypeInterface>(firstTarget.getType()))
+      for (auto param : state.getParams(firstTarget))
+        zippedArgs.append({{param}});
+    else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << firstTarget.getType();
+
+  for (auto target : getTargets().drop_front(1)) {
+    // Append each element of payload to the co-indexed body-arguments-as-tuple.
+    size_t payloadSize = 0;
+    if (isa<TransformHandleTypeInterface>(target.getType())) {
+      for (auto op : state.getPayloadOps(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({op});
+    } else if (isa<TransformValueHandleTypeInterface>(target.getType())) {
+      for (auto val : state.getPayloadValues(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({val});
+    } else if (isa<TransformParamTypeInterface>(target.getType())) {
+      for (auto param : state.getParams(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({param});
+    } else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << target.getType();
----------------
ftynse wrote:

```suggestion
    } else {
      return emitDefiniteFailure()
             << "unhandled handle type " << target.getType();
    }
```

https://github.com/llvm/llvm-project/pull/93705


More information about the Mlir-commits mailing list