[Mlir-commits] [mlir] [mlir][Transform] Extend transform.foreach to take multiple arguments (PR #93705)

Rolf Morel llvmlistbot at llvm.org
Thu May 30 09:57:17 PDT 2024


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/93705

From e2f72148ead3871e589cd5f1283d384d7849bdd2 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at huawei.com>
Date: Thu, 23 May 2024 19:48:23 +0800
Subject: [PATCH 1/2] [mlir][Transform] Extend transform.foreach to take
 multiple arguments

Changes transform.foreach's interface to take multiple arguments, e.g.
transform.foreach %ops1, %ops2, %params : ... { ^bb0(%op1, %op2, %param): BODY }
The payloads of these handles are iterated over as if they had been zipped
together: BODY is executed once for each resulting tuple. As the documentation
explains, this implementation requires all payloads to have the same length.
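
The zipped iteration can be pictured with the small standalone C++ model below. It is not MLIR code. `Payload` and `zipPayloads` are hypothetical names, and strings stand in for ops, values and params. Returning `std::nullopt` stands in for the silenceable failure reported on a size mismatch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for one payload element (an op, a value or a param).
using Payload = std::string;

// targets[t] lists the payload of the t-th target handle, in order.
// Returns one argument tuple per iteration of the body, or std::nullopt if
// the targets' payloads differ in size (a silenceable failure in MLIR).
std::optional<std::vector<std::vector<Payload>>>
zipPayloads(const std::vector<std::vector<Payload>> &targets) {
  std::vector<std::vector<Payload>> tuples;
  if (targets.empty())
    return tuples;
  const std::size_t numIterations = targets.front().size();
  for (const auto &payload : targets)
    if (payload.size() != numIterations)
      return std::nullopt;
  tuples.resize(numIterations);
  for (std::size_t i = 0; i < numIterations; ++i)
    for (const auto &payload : targets)
      tuples[i].push_back(payload[i]); // i-th element of every target
  assert(tuples.size() == numIterations);
  return tuples;
}
```

For example, targets with payloads `{op1, op2}` and `{p1, p2}` produce the tuples `(op1, p1)` and `(op2, p2)`, so BODY would run twice.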

This change also allows each target to be an op, value, or param handle.

The added test cases demonstrate some use cases for this change.
---
 .../mlir/Dialect/Transform/IR/TransformOps.td |  57 +++----
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 141 ++++++++++++++----
 .../Dialect/Linalg/multisize-tiling-full.mlir |  21 ++-
 .../SCF/transform-loop-fuse-sibling.mlir      |  73 +++++++++
 mlir/test/Dialect/Transform/ops.mlir          |  22 ++-
 .../Dialect/Transform/test-interpreter.mlir   |  85 +++++++++++
 6 files changed, 332 insertions(+), 67 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 77048a28d7510..e61cd77339ac6 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -614,43 +614,48 @@ def ForeachOp : TransformDialectOp<"foreach",
          "getSuccessorRegions", "getEntrySuccessorOperands"]>,
      SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
     ]> {
-  let summary = "Executes the body for each payload op";
+  let summary = "Executes the body for each element of the payload";
   let description = [{
-    This op has exactly one region with exactly one block ("body"). The body is
-    executed for each payload op that is associated to the target operand in an
-    unbatched fashion. I.e., the block argument ("iteration variable") is always
-    mapped to exactly one payload op.
-
-    This op always reads the target handle. Furthermore, it consumes the handle
-    if there is a transform op in the body that consumes the iteration variable.
-    This op does not return anything.
-
-    The transformations inside the body are applied in order of their
-    appearance. During application, if any transformation in the sequence fails,
-    the entire sequence fails immediately leaving the payload IR in potentially
-    invalid state, i.e., this operation offers no transformation rollback
-    capabilities.
-
-    This op generates as many handles as the terminating YieldOp has operands.
-    For each result, the payload ops of the corresponding YieldOp operand are
-    merged and mapped to the same resulting handle.
+    Execute the op's body - its single region block - exactly once per
+    element of the payload associated with the target handles. The body's
+    transformations are applied in order of appearance until reaching the
+    (implicit) YieldOp terminator.
+
+    Each iteration co-indexes the payloads of the targets and maps the body's
+    block arguments to the resulting tuple, as though iterating over the
+    zipped-together `targets`. As such, in each iteration, the payload of each
+    of the body's block arguments has exactly one element.
+
+    This op always reads the target handles. Furthermore, it consumes a handle
+    if there is a transform op in the body that consumes the corresponding
+    block argument. Handles can point to ops, values, or parameters.
+
+    #### Return Modes
+
+    This op produces as many result handles as the body's terminating YieldOp
+    has operands. For each result, the payloads of the corresponding YieldOp
+    operand are merged and mapped to the same resulting handle.
+
+    If the target handles do not associate payloads of the same size, a
+    silenceable failure will be generated.
+
+    During application, if any transformation in the sequence fails, the entire
+    sequence fails immediately with the same failure, leaving the payload IR in
+    a potentially invalid state, i.e., this operation offers no transformation
+    rollback capabilities.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+  let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat =
-    "$target `:` type($target) (`->` type($results)^)? $body attr-dict";
+    "$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Allow the dialect prefix to be omitted.
     static StringRef getDefaultDialect() { return "transform"; }
 
-    BlockArgument getIterationVariable() {
-      return getBody().front().getArgument(0);
-    }
-
     transform::YieldOp getYieldOp();
   }];
 }
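
The result-merging rule in the description above can be modelled in standalone C++ as follows. This is a sketch, not the MLIR implementation. `mergeYieldedPayloads` is a hypothetical name, and strings stand in for payload elements.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for one payload element (an op, a value or a param).
using Payload = std::string;

// yields[i][r] is the payload that yield operand r maps to at the end of
// iteration i; it may hold any number of elements, including none.
// Result r of the foreach is the concatenation, in iteration order, of the
// payloads yielded for operand r.
std::vector<std::vector<Payload>> mergeYieldedPayloads(
    const std::vector<std::vector<std::vector<Payload>>> &yields,
    std::size_t numResults) {
  std::vector<std::vector<Payload>> results(numResults);
  for (const auto &iteration : yields) {
    assert(iteration.size() == numResults && "one payload per yield operand");
    for (std::size_t r = 0; r < numResults; ++r)
      results[r].insert(results[r].end(), iteration[r].begin(),
                        iteration[r].end());
  }
  return results;
}
```

Each iteration may yield zero or more elements per operand. As a result, the size of a result's payload need not equal the number of iterations.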
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 247759e21efb1..b58f4e0672fc2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1391,15 +1391,62 @@ DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
-  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-  // Store payload ops in a vector because ops may be removed from the mapping
-  // by the TrackingRewriter while the iteration is in progress.
-  SmallVector<Operation *> targets =
-      llvm::to_vector(state.getPayloadOps(getTarget()));
-  for (Operation *op : targets) {
+  // Collect the arguments with which to call each iteration of the body.
+  // We store the payload before executing the body as ops may be removed from
+  // the mapping by the TrackingRewriter while the iteration is in progress.
+  SmallVector<SmallVector<MappedValue>> zippedArgs;
+  for (auto firstTarget : getTargets().take_front(1)) // Loop runs at most once.
+    // For each element, init a tuple with which to call the body later on.
+    if (isa<TransformHandleTypeInterface>(firstTarget.getType()))
+      for (auto &op : state.getPayloadOps(firstTarget))
+        zippedArgs.append({{op}}); // NB: append's argument is an init-list.
+    else if (isa<TransformValueHandleTypeInterface>(firstTarget.getType()))
+      for (auto val : state.getPayloadValues(firstTarget))
+        zippedArgs.append({{val}});
+    else if (isa<TransformParamTypeInterface>(firstTarget.getType()))
+      for (auto param : state.getParams(firstTarget))
+        zippedArgs.append({{param}});
+    else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << firstTarget.getType();
+
+  for (auto target : getTargets().drop_front(1)) {
+    // Append each element of payload to the co-indexed body-arguments-as-tuple.
+    size_t payloadSize = 0;
+    if (isa<TransformHandleTypeInterface>(target.getType())) {
+      for (auto op : state.getPayloadOps(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({op});
+    } else if (isa<TransformValueHandleTypeInterface>(target.getType())) {
+      for (auto val : state.getPayloadValues(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({val});
+    } else if (isa<TransformParamTypeInterface>(target.getType())) {
+      for (auto param : state.getParams(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({param});
+    } else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << target.getType();
+
+    if (payloadSize != zippedArgs.size())
+      return emitSilenceableError()
+             << "payload size of prior targets (" << zippedArgs.size()
+             << ") differs from payload size (" << payloadSize << ") of target "
+             << target;
+  }
+
+  // For each argument tuple collected above, execute the body region.
+  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
+  for (SmallVector<MappedValue> &argsTuple : zippedArgs) {
     auto scope = state.make_region_scope(getBody());
-    if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
-      return DiagnosedSilenceableFailure::definiteFailure();
+    // Set up arguments to the region's block.
+    for (auto &&[blockArg, argument] :
+         llvm::zip_equal(getBody().front().getArguments(), argsTuple))
+      // Note: each blockArg's handle gets associated with just a single element
+      // from the corresponding target's payload.
+      if (failed(state.mapBlockArgument(blockArg, {argument})))
+        return DiagnosedSilenceableFailure::definiteFailure();
 
     // Execute loop body.
     for (Operation &transform : getBody().front().without_terminator()) {
@@ -1409,28 +1456,44 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
         return result;
     }
 
-    // Append yielded payload ops to result list (if any).
-    for (unsigned i = 0; i < getNumResults(); ++i) {
-      auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
-      resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
-    }
+    // Append yielded payloads to results.
+    auto yieldOperands = getYieldOp().getOperands();
+    for (auto &&[result, yieldOperand, resTuple] :
+         llvm::zip_equal(getResults(), yieldOperands, zippedResults))
+      // NB: each iteration may add any number of ops/values/params to a result.
+      if (isa<TransformHandleTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
+      else if (isa<TransformValueHandleTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
+      else if (isa<TransformParamTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getParams(yieldOperand));
+      else
+        return emitDefiniteFailure()
+               << "unhandled handle type " << result.getType();
   }
 
   for (unsigned i = 0; i < getNumResults(); ++i)
-    results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
+    results.setMappedValues(cast<OpResult>(getResult(i)), zippedResults[i]);
 
   return DiagnosedSilenceableFailure::success();
 }
 
 void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  BlockArgument iterVar = getIterationVariable();
-  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
-        return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
-      })) {
-    consumesHandle(getTarget(), effects);
-  } else {
-    onlyReadsHandle(getTarget(), effects);
+
+  // NB: ideally this `zip` would be `zip_equal`, but although the verifier
+  // catches arity errors, this method may run before (or without) `verify()`.
+  for (auto &&[target, blockArg] :
+       llvm::zip(getTargets(), getBody().front().getArguments())) {
+    BlockArgument blockArgument = blockArg;
+    if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+          return isHandleConsumed(blockArgument,
+                                  cast<TransformOpInterface>(&op));
+        })) {
+      consumesHandle(target, effects);
+    } else {
+      onlyReadsHandle(target, effects);
+    }
   }
 
   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
@@ -1463,6 +1526,7 @@ void transform::ForeachOp::getSuccessorRegions(
 
 OperandRange
 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  // TODO: figure out how to update the comment & impl if necessary
   // The iteration variable op handle is mapped to a subset (one op to be
   // precise) of the payload ops of the ForeachOp operand.
   assert(point == getBody() && "unexpected region index");
@@ -1474,14 +1538,33 @@ transform::YieldOp transform::ForeachOp::getYieldOp() {
 }
 
 LogicalResult transform::ForeachOp::verify() {
-  auto yieldOp = getYieldOp();
-  if (getNumResults() != yieldOp.getNumOperands())
-    return emitOpError() << "expects the same number of results as the "
-                            "terminator has operands";
-  for (Value v : yieldOp.getOperands())
-    if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
-      return yieldOp->emitOpError("expects operands to have types implementing "
-                                  "TransformHandleTypeInterface");
+  for (auto [targetOpt, bodyArgOpt] :
+       llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
+    if (!targetOpt || !bodyArgOpt)
+      return emitOpError() << "expects the same number of targets as the body "
+                              "has block arguments";
+    auto target = targetOpt.value();
+    if (target.getType() != bodyArgOpt.value().getType() ||
+        !isa<TransformHandleTypeInterface, TransformValueHandleTypeInterface,
+             TransformParamTypeInterface>(target.getType()))
+      return emitOpError(
+          "expects co-indexed targets and the body's "
+          "block arguments to have the same op/value/param type");
+  }
+
+  for (auto [resultOpt, yieldOperandOpt] :
+       llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
+    if (!resultOpt || !yieldOperandOpt)
+      return emitOpError() << "expects the same number of results as the "
+                              "yield terminator has operands";
+    auto result = resultOpt.value();
+    if (result.getType() != yieldOperandOpt.value().getType() ||
+        !isa<TransformHandleTypeInterface, TransformValueHandleTypeInterface,
+             TransformParamTypeInterface>(result.getType()))
+      return emitOpError("expects co-indexed results and yield "
+                         "operands to have the same op/value/param type");
+  }
+
   return success();
 }
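
The target checks in `verify()` and the per-target effects in `getEffects()` can be sketched outside MLIR as below. The sketch makes several simplifying assumptions:
- Types are plain strings.
- `bodyConsumes` is a hypothetical boolean matrix standing in for `isHandleConsumed()`.
- `verifyTargets` and `targetEffects` are made-up names.

It only covers targets versus block arguments, not results versus yield operands.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Like the first loop in ForeachOp::verify(): targets and block arguments must
// pair up one-to-one and co-indexed ones must have identical types. Returns an
// empty string on success, otherwise the diagnostic text.
std::string verifyTargets(const std::vector<std::string> &targetTypes,
                          const std::vector<std::string> &blockArgTypes) {
  if (targetTypes.size() != blockArgTypes.size())
    return "expects the same number of targets as the body has block "
           "arguments";
  for (std::size_t i = 0; i < targetTypes.size(); ++i)
    if (targetTypes[i] != blockArgTypes[i])
      return "expects co-indexed targets and the body's block arguments to "
             "have the same op/value/param type";
  return "";
}

enum class HandleEffect { OnlyReads, Consumes };

// Like ForeachOp::getEffects(): target t is consumed iff some body op consumes
// block argument t, and is only read otherwise. bodyConsumes[o][a] tells
// whether body op o consumes block argument a.
std::vector<HandleEffect>
targetEffects(std::size_t numTargets, std::size_t numBlockArgs,
              const std::vector<std::vector<bool>> &bodyConsumes) {
  // Like llvm::zip in the patch, stop at the shorter of the two sequences so
  // unverified (mismatched) input does not index out of bounds.
  const std::size_t numPairs = std::min(numTargets, numBlockArgs);
  std::vector<HandleEffect> effects(numPairs, HandleEffect::OnlyReads);
  for (std::size_t t = 0; t < numPairs; ++t)
    for (const auto &op : bodyConsumes) {
      assert(op.size() == numBlockArgs && "one flag per block argument");
      if (op[t])
        effects[t] = HandleEffect::Consumes;
    }
  return effects;
}
```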
 
diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
index 15b24b56608e3..51332ffce03d1 100644
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -6,15 +6,17 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
-    %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
     %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
     %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
-    %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op
-    %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.any_op
-    transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.foreach %5 : !transform.any_op {
+    ^bb0(%inner_linalg: !transform.any_op):
+      %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
+      %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
+      transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    }
     transform.yield
   }
 }
@@ -114,9 +116,12 @@ module attributes {transform.with_named_sequence} {
     %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
     %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
     %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
-    %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.param<i64>
-    transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
-    transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+    transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
+    ^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
+      %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
+      transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+      transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+    }
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 0f51b1cdbe0cf..54dd2bdf953ca 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -328,3 +328,76 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+// -----
+
+// CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @foreach_loop_pair_fuse(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  } {target_loops}
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  } {source_loops}
+  %2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>)  {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %5 = arith.addf %3, %2 : vector<32xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  } {target_loops}
+  %dup2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<32xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  } {source_loops}
+  return %1, %dup1, %2, %dup2 : tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>
+}
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %target_loops = transform.structured.match ops{["scf.for"]} attributes {target_loops} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %source_loops = transform.structured.match ops{["scf.for"]} attributes {source_loops} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.foreach %target_loops, %source_loops : !transform.any_op, !transform.any_op {
+    ^bb0(%target_loop: !transform.any_op, %source_loop: !transform.any_op):
+      %fused = transform.loop.fuse_sibling %target_loop into %source_loop : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    }
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index b03a9f4d760d2..e9baffde262fa 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -68,11 +68,25 @@ transform.sequence failures(propagate) {
 }
 
 // CHECK: transform.sequence
-// CHECK: foreach
 transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op):
-  transform.foreach %arg0 : !transform.any_op {
-  ^bb1(%arg1: !transform.any_op):
+^bb0(%op0: !transform.any_op, %val0: !transform.any_value, %par0: !transform.any_param):
+  // CHECK: foreach %{{.*}} : !transform.any_op
+  transform.foreach %op0 : !transform.any_op {
+  ^bb1(%op1: !transform.any_op):
+  }
+  // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param
+  transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param {
+  ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+  }
+  // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op
+  transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op {
+  ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+    transform.yield %op1 : !transform.any_op
+  }
+  // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value
+  transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value {
+  ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+    transform.yield %par1, %val1 : !transform.any_param, !transform.any_value
   }
 }
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b6850e2024d53..0bdd6638b3e55 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -830,6 +830,91 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %results, %types = transform.foreach %0 : !transform.any_op -> !transform.any_value, !transform.any_param {
+    ^bb0(%op0 : !transform.any_op):
+      %result = transform.get_result %op0[0] : (!transform.any_op) -> !transform.any_value
+      %type = transform.get_type elemental %result : (!transform.any_value) -> !transform.any_param
+      transform.yield %result, %type : !transform.any_value, !transform.any_param
+    }
+    transform.debug.emit_remark_at %results, "result selected" : !transform.any_value
+    transform.debug.emit_param_as_remark %types, "elemental types" at %0 : !transform.any_param, !transform.any_op
+
+    transform.yield
+  }
+}
+
+func.func @payload(%lhs: tensor<10x20xf16>,
+                   %rhs: tensor<20x15xf32>) -> (tensor<10x15xf64>, tensor<10x15xf32>) {
+  %cst64 = arith.constant 0.0 : f64
+  %empty64 = tensor.empty() : tensor<10x15xf64>
+  %fill64 = linalg.fill ins(%cst64 : f64) outs(%empty64 : tensor<10x15xf64>) -> tensor<10x15xf64>
+  // expected-remark @below {{result selected}}
+  // expected-note @below {{value handle points to an op result #0}}
+  // expected-remark @below {{elemental types f64, f32}}
+  %result64 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>)
+                         outs(%fill64: tensor<10x15xf64>) -> tensor<10x15xf64>
+
+  %cst32 = arith.constant 0.0 : f32
+  %empty32 = tensor.empty() : tensor<10x15xf32>
+  %fill32 = linalg.fill ins(%cst32 : f32) outs(%empty32 : tensor<10x15xf32>) -> tensor<10x15xf32>
+  // expected-remark @below {{result selected}}
+  // expected-note @below {{value handle points to an op result #0}}
+  // expected-remark @below {{elemental types f64, f32}}
+  %result32 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>)
+                           outs(%fill32: tensor<10x15xf32>) -> tensor<10x15xf32>
+
+  return %result64, %result32 : tensor<10x15xf64>, tensor<10x15xf32>
+
+}
+
+// -----
+
+func.func @two_const_ops() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %two_ops = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %one_param = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+    // expected-error @below {{payload size of prior targets (2) differs from payload size (1) of target}}
+    transform.foreach %two_ops, %one_param : !transform.any_op, !transform.test_dialect_param {
+    ^bb2(%op: !transform.any_op, %param: !transform.test_dialect_param):
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @one_const_op() {
+  %0 = arith.constant 0 : index
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %one_op = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %one_val = transform.test_produce_value_handle_to_self_operand %one_op : (!transform.any_op) -> !transform.any_value
+    %param_one = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+    %param_two = transform.param.constant 2 : i32 -> !transform.test_dialect_param
+    %two_params = transform.merge_handles %param_one, %param_two : !transform.test_dialect_param
+
+    // expected-error @below {{payload size of prior targets (1) differs from payload size (2) of target}}
+    transform.foreach %one_val, %one_op, %two_params : !transform.any_value, !transform.any_op, !transform.test_dialect_param {
+    ^bb2(%val: !transform.any_value, %op: !transform.any_op, %param: !transform.test_dialect_param):
+    }
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @consume_in_foreach()
 //  CHECK-NEXT:   return
 func.func @consume_in_foreach() {
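
The two new failure tests depend on the order in which `apply()` compares payload sizes. The hypothetical C++ sketch below models that check: the first target fixes the expected size, and each later target is compared against it. Unlike the real diagnostic, which prints the offending target value, `payloadSizeError` (a made-up helper) reports the target's index.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// payloadSizes[t] is the number of payload elements of target t. The first
// target fixes the expected size; the first later target that differs
// produces the diagnostic. Returns "" when all sizes agree.
std::string payloadSizeError(const std::vector<std::size_t> &payloadSizes) {
  for (std::size_t t = 1; t < payloadSizes.size(); ++t)
    if (payloadSizes[t] != payloadSizes[0])
      return "payload size of prior targets (" +
             std::to_string(payloadSizes[0]) + ") differs from payload size (" +
             std::to_string(payloadSizes[t]) + ") of target #" +
             std::to_string(t);
  return "";
}
```

Sizes `{2, 1}` (two `arith.constant` ops, one param) are reported at the second target. Sizes `{1, 1, 2}` are only caught at the third target, `%two_params`.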

From a5f147e53112db3636c7713064c06e36958f58af Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at huawei.com>
Date: Fri, 31 May 2024 00:12:03 +0800
Subject: [PATCH 2/2] Rewrite and fixes based on @ftynse's review.

---
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 103 ++++++------------
 .../Dialect/Transform/test-interpreter.mlir   |   4 +-
 2 files changed, 37 insertions(+), 70 deletions(-)
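
The rewritten `apply()` below keeps one payload list per target and takes the iteration count from the first target. It no longer builds per-iteration tuples up front. A standalone C++ sketch of that layout follows. Because the hunk is truncated here, the sketch assumes that iteration `i` binds block argument `t` to element `i` of target `t`'s payload once the sizes have been checked. `iterationArgs` and `numIterations` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for one payload element (an op, a value or a param).
using Payload = std::string;

// The number of iterations is taken from the first target, or 0 if there
// are no targets at all.
std::size_t numIterations(const std::vector<std::vector<Payload>> &payloads) {
  return payloads.empty() ? 0 : payloads.front().size();
}

// payloads[t] is the full payload of target t. Returns the block arguments
// for iteration iterIdx: the iterIdx-th element of every target's payload.
std::vector<Payload>
iterationArgs(const std::vector<std::vector<Payload>> &payloads,
              std::size_t iterIdx) {
  std::vector<Payload> args;
  args.reserve(payloads.size());
  for (const auto &payload : payloads) {
    assert(iterIdx < payload.size() && "sizes are checked before iterating");
    args.push_back(payload[iterIdx]);
  }
  return args;
}
```

This pairing matches the loop-fusion test in the first patch. There, the i-th `target_loops` loop is passed to the body together with the i-th `source_loops` loop.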

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index b58f4e0672fc2..1a7ec030f0eb1 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1391,76 +1391,51 @@ DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
-  // Collect the arguments with which to call each iteration of the body.
-  // We store the payload before executing the body as ops may be removed from
-  // the mapping by the TrackingRewriter while the iteration is in progress.
-  SmallVector<SmallVector<MappedValue>> zippedArgs;
-  for (auto firstTarget : getTargets().take_front(1)) // Loop runs at most once.
-    // For each element, init a tuple with which to call the body later on.
-    if (isa<TransformHandleTypeInterface>(firstTarget.getType()))
-      for (auto &op : state.getPayloadOps(firstTarget))
-        zippedArgs.append({{op}}); // NB: append's argument is an init-list.
-    else if (isa<TransformValueHandleTypeInterface>(firstTarget.getType()))
-      for (auto val : state.getPayloadValues(firstTarget))
-        zippedArgs.append({{val}});
-    else if (isa<TransformParamTypeInterface>(firstTarget.getType()))
-      for (auto param : state.getParams(firstTarget))
-        zippedArgs.append({{param}});
-    else
-      return emitDefiniteFailure()
-             << "unhandled handle type " << firstTarget.getType();
-
-  for (auto target : getTargets().drop_front(1)) {
-    // Append each element of payload to the co-indexed body-arguments-as-tuple.
-    size_t payloadSize = 0;
-    if (isa<TransformHandleTypeInterface>(target.getType())) {
-      for (auto op : state.getPayloadOps(target))
-        if (++payloadSize <= zippedArgs.size())
-          zippedArgs[payloadSize - 1].append({op});
-    } else if (isa<TransformValueHandleTypeInterface>(target.getType())) {
-      for (auto val : state.getPayloadValues(target))
-        if (++payloadSize <= zippedArgs.size())
-          zippedArgs[payloadSize - 1].append({val});
-    } else if (isa<TransformParamTypeInterface>(target.getType())) {
-      for (auto param : state.getParams(target))
-        if (++payloadSize <= zippedArgs.size())
-          zippedArgs[payloadSize - 1].append({param});
-    } else
-      return emitDefiniteFailure()
-             << "unhandled handle type " << target.getType();
-
-    if (payloadSize != zippedArgs.size())
+  // We store the payloads before executing the body as ops may be removed from
+  // the mapping by the TrackingRewriter while iteration is in progress.
+  SmallVector<SmallVector<MappedValue>> payloads;
+  detail::prepareValueMappings(payloads, getTargets(), state);
+  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
+
+  // As we will be "zipping" over them, check all payloads have the same size.
+  for (size_t argIdx = 1; argIdx < payloads.size(); argIdx++) {
+    if (payloads[argIdx].size() != numIterations) {
       return emitSilenceableError()
-             << "payload size of prior targets (" << zippedArgs.size()
-             << ") differs from payload size (" << payloadSize << ") of target "
-             << target;
+             << "prior targets' payload size (" << numIterations
+             << ") differs from payload size (" << payloads[argIdx].size()
+             << ") of target " << getTargets()[argIdx];
+    }
   }
 
-  // For each arguments-as-tuple collected up above, execute the body region.
+  // Start iterating, indexing into payloads to obtain the right arguments to
+  // call the body with: the elements at the same iteration index across all
+  // payloads form the tuple used as the body's block arguments.
+  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
   SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
-  for (SmallVector<MappedValue> &argsTuple : zippedArgs) {
+  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
     auto scope = state.make_region_scope(getBody());
     // Set up arguments to the region's block.
-    for (auto &&[blockArg, argument] :
-         llvm::zip_equal(getBody().front().getArguments(), argsTuple))
-      // Note: each blockArg's handle gets associated with just a single element
-      // from the corresponding target's payload.
+    for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
+      MappedValue argument = payloads[argIdx][iterIdx];
+      // Note that each blockArg's handle gets associated with just a single
+      // element from the corresponding target's payload.
       if (failed(state.mapBlockArgument(blockArg, {argument})))
         return DiagnosedSilenceableFailure::definiteFailure();
+    }
 
     // Execute loop body.
     for (Operation &transform : getBody().front().without_terminator()) {
       DiagnosedSilenceableFailure result = state.applyTransform(
-          cast<transform::TransformOpInterface>(transform));
+          llvm::cast<transform::TransformOpInterface>(transform));
       if (!result.succeeded())
         return result;
     }
 
-    // Append yielded payloads to results.
-    auto yieldOperands = getYieldOp().getOperands();
+    // Append yielded payloads to corresponding results from prior iterations.
+    OperandRange yieldOperands = getYieldOp().getOperands();
     for (auto &&[result, yieldOperand, resTuple] :
          llvm::zip_equal(getResults(), yieldOperands, zippedResults))
-      // NB: each iteration we add any number of ops/vals/params to an opresult.
+      // NB: an iteration may append any number of ops/vals/params to a result.
       if (isa<TransformHandleTypeInterface>(result.getType()))
         llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
       else if (isa<TransformValueHandleTypeInterface>(result.getType()))
@@ -1468,19 +1443,18 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
       else if (isa<TransformParamTypeInterface>(result.getType()))
         llvm::append_range(resTuple, state.getParams(yieldOperand));
       else
-        return emitDefiniteFailure()
-               << "unhandled handle type " << result.getType();
+        llvm_unreachable("unhandled handle type");
   }
 
-  for (unsigned i = 0; i < getNumResults(); ++i)
-    results.setMappedValues(cast<OpResult>(getResult(i)), zippedResults[i]);
+  // Associate the accumulated result payloads with the op's actual results.
+  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
+    results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
 
   return DiagnosedSilenceableFailure::success();
 }
 
 void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-
   // NB: this `zip` should be `zip_equal` - while this op's verifier catches
   // arity errors, this method might get called before/in absence of `verify()`.
   for (auto &&[target, blockArg] :
@@ -1526,9 +1500,8 @@ void transform::ForeachOp::getSuccessorRegions(
 
 OperandRange
 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  // TODO: figure out how to update the comment & impl if necessary
-  // The iteration variable op handle is mapped to a subset (one op to be
-  // precise) of the payload ops of the ForeachOp operand.
+  // Each block argument handle is mapped to a subset (a single op, value or
+  // param) of the payload of the corresponding `targets` operand of ForeachOp.
   assert(point == getBody() && "unexpected region index");
   return getOperation()->getOperands();
 }
@@ -1543,10 +1516,7 @@ LogicalResult transform::ForeachOp::verify() {
     if (!targetOpt || !bodyArgOpt)
       return emitOpError() << "expects the same number of targets as the body "
                               "has block arguments";
-    auto target = targetOpt.value();
-    if (target.getType() != bodyArgOpt.value().getType() ||
-        !isa<TransformHandleTypeInterface, TransformValueHandleTypeInterface,
-             TransformParamTypeInterface>(target.getType()))
+    if (targetOpt.value().getType() != bodyArgOpt.value().getType())
       return emitOpError(
           "expects co-indexed targets and the body's "
           "block arguments to have the same op/value/param type");
@@ -1557,10 +1527,7 @@ LogicalResult transform::ForeachOp::verify() {
     if (!resultOpt || !yieldOperandOpt)
       return emitOpError() << "expects the same number of results as the "
                               "yield terminator has operands";
-    auto result = resultOpt.value();
-    if (result.getType() != yieldOperandOpt.value().getType() ||
-        !isa<TransformHandleTypeInterface, TransformValueHandleTypeInterface,
-             TransformParamTypeInterface>(result.getType()))
+    if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
       return emitOpError("expects co-indexed results and yield "
                          "operands to have the same op/value/param type");
   }
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 0bdd6638b3e55..4fe2dbedff56e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -882,7 +882,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %two_ops = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %one_param = transform.param.constant 1 : i32 -> !transform.test_dialect_param
-    // expected-error @below {{payload size of prior targets (2) differs from payload size (1) of target}}
+    // expected-error @below {{prior targets' payload size (2) differs from payload size (1) of target}}
     transform.foreach %two_ops, %one_param : !transform.any_op, !transform.test_dialect_param {
     ^bb2(%op: !transform.any_op, %param: !transform.test_dialect_param):
     }
@@ -905,7 +905,7 @@ module attributes {transform.with_named_sequence} {
     %param_two = transform.param.constant 2 : i32 -> !transform.test_dialect_param
     %two_params = transform.merge_handles %param_one, %param_two : !transform.test_dialect_param
 
-    // expected-error @below {{payload size of prior targets (1) differs from payload size (2) of target}}
+    // expected-error @below {{prior targets' payload size (1) differs from payload size (2) of target}}
     transform.foreach %one_val, %one_op, %two_params : !transform.any_value, !transform.any_op, !transform.test_dialect_param {
     ^bb2(%val: !transform.any_value, %op: !transform.any_op, %param: !transform.test_dialect_param):
     }

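In the `ForeachOp::apply` change above, each iteration appends whatever its `transform.yield` produced to the corresponding result, so a single result can collect a variable number of items per iteration. The following sketches that accumulation step in isolation. It is an illustrative assumption, not the MLIR API. `Body` models the loop body as a callback that returns one payload list per result. `accumulateResults` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for MLIR's MappedValue.
using Payload = std::string;

// Hypothetical model of the loop body: given one argument tuple, returns for
// each result the (possibly empty) list of payload items yielded this time.
using Body = std::function<std::vector<std::vector<Payload>>(
    const std::vector<Payload> &)>;

// Runs the body once per tuple and appends each iteration's yielded items to
// the matching result, like the zippedResults accumulation in apply().
std::vector<std::vector<Payload>>
accumulateResults(const std::vector<std::vector<Payload>> &argTuples,
                  std::size_t numResults, const Body &body) {
  std::vector<std::vector<Payload>> zippedResults(numResults);
  for (const auto &args : argTuples) {
    std::vector<std::vector<Payload>> yielded = body(args);
    assert(yielded.size() == numResults && "yield arity must match results");
    for (std::size_t r = 0; r < numResults; ++r)
      zippedResults[r].insert(zippedResults[r].end(), yielded[r].begin(),
                              yielded[r].end());
  }
  return zippedResults;
}
```

Appending, rather than overwriting, is what allows an iteration to contribute zero, one or several items to a result.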