[Mlir-commits] [mlir] [mlir][Transform] Extend transform.foreach to take multiple arguments (PR #93705)

Rolf Morel llvmlistbot at llvm.org
Thu May 30 04:00:47 PDT 2024


================
@@ -1409,28 +1456,44 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
         return result;
     }
 
-    // Append yielded payload ops to result list (if any).
-    for (unsigned i = 0; i < getNumResults(); ++i) {
-      auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
-      resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
-    }
+    // Append yielded payloads to results.
+    auto yieldOperands = getYieldOp().getOperands();
+    for (auto &&[result, yieldOperand, resTuple] :
+         llvm::zip_equal(getResults(), yieldOperands, zippedResults))
+      // NB: each iteration we add any number of ops/vals/params to an opresult.
+      if (isa<TransformHandleTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
+      else if (isa<TransformValueHandleTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
+      else if (isa<TransformParamTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getParams(yieldOperand));
+      else
+        return emitDefiniteFailure()
+               << "unhandled handle type " << result.getType();
   }
 
   for (unsigned i = 0; i < getNumResults(); ++i)
-    results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
+    results.setMappedValues(cast<OpResult>(getResult(i)), zippedResults[i]);
 
   return DiagnosedSilenceableFailure::success();
 }
 
 void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  BlockArgument iterVar = getIterationVariable();
-  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
-        return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
-      })) {
-    consumesHandle(getTarget(), effects);
-  } else {
-    onlyReadsHandle(getTarget(), effects);
+
+  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
+  // arity errors, this method might get called before/in absence of `verify()`.
----------------
rolfmorel wrote:

As far as I understand, there is no function named `zip_shortest`; plain `zip` already has "shortest" semantics (iteration stops when the shortest range is exhausted). Let me know if that's not right, though.

https://github.com/llvm/llvm-project/pull/93705


More information about the Mlir-commits mailing list