[Mlir-commits] [mlir] [mlir][Transform] Extend transform.foreach to take multiple arguments (PR #93705)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Thu May 30 00:59:48 PDT 2024


================
@@ -1391,15 +1391,62 @@ DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
-  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-  // Store payload ops in a vector because ops may be removed from the mapping
-  // by the TrackingRewriter while the iteration is in progress.
-  SmallVector<Operation *> targets =
-      llvm::to_vector(state.getPayloadOps(getTarget()));
-  for (Operation *op : targets) {
+  // Collect the arguments with which to call each iteration of the body.
+  // We store the payload before executing the body as ops may be removed from
+  // the mapping by the TrackingRewriter while the iteration is in progress.
+  SmallVector<SmallVector<MappedValue>> zippedArgs;
+  for (auto firstTarget : getTargets().take_front(1)) // Loop runs at most once.
+    // For each element, init a tuple with which to call the body later on.
+    if (isa<TransformHandleTypeInterface>(firstTarget.getType()))
+      for (auto &op : state.getPayloadOps(firstTarget))
+        zippedArgs.append({{op}}); // NB: append's argument is an init-list.
+    else if (isa<TransformValueHandleTypeInterface>(firstTarget.getType()))
+      for (auto val : state.getPayloadValues(firstTarget))
+        zippedArgs.append({{val}});
+    else if (isa<TransformParamTypeInterface>(firstTarget.getType()))
+      for (auto param : state.getParams(firstTarget))
+        zippedArgs.append({{param}});
+    else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << firstTarget.getType();
+
+  for (auto target : getTargets().drop_front(1)) {
+    // Append each element of payload to the co-indexed body-arguments-as-tuple.
+    size_t payloadSize = 0;
+    if (isa<TransformHandleTypeInterface>(target.getType())) {
+      for (auto op : state.getPayloadOps(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({op});
+    } else if (isa<TransformValueHandleTypeInterface>(target.getType())) {
+      for (auto val : state.getPayloadValues(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({val});
+    } else if (isa<TransformParamTypeInterface>(target.getType())) {
+      for (auto param : state.getParams(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({param});
+    } else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << target.getType();
+
+    if (payloadSize != zippedArgs.size())
+      return emitSilenceableError()
+             << "payload size of prior targets (" << zippedArgs.size()
+             << ") differs from payload size (" << payloadSize << ") of target "
+             << target;
+  }
+
+  // For each arguments-as-tuple collected up above, execute the body region.
+  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
+  for (SmallVector<MappedValue> &argsTuple : zippedArgs) {
     auto scope = state.make_region_scope(getBody());
-    if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
-      return DiagnosedSilenceableFailure::definiteFailure();
+    // Set up arguments to the region's block.
+    for (auto &&[blockArg, argument] :
+         llvm::zip_equal(getBody().front().getArguments(), argsTuple))
+      // Note: each blockArg's handle gets associated with just a single element
+      // from the corresponding target's payload.
+      if (failed(state.mapBlockArgument(blockArg, {argument})))
----------------
ftynse wrote:

I wonder if the entire dance of transposing the payload "matrix" is worth it. Can't we just call `prepareValueMappings` on the operands, check that all nested vectors have the same length, and then index with `[j][i]` using counted loops? I'm not convinced that would be less efficient (the access is transposed, but it uses less memory, and most of these matrices easily fit in cache), and it would most likely be easier to read.

https://github.com/llvm/llvm-project/pull/93705


More information about the Mlir-commits mailing list