[Mlir-commits] [mlir] d462bf6 - [mlir][Transform] Extend transform.foreach to take multiple arguments (#93705)

llvmlistbot at llvm.org
Fri Jun 14 08:02:51 PDT 2024


Author: Rolf Morel
Date: 2024-06-14T17:02:47+02:00
New Revision: d462bf687548a5630f60a8afaa66120df8319e88

URL: https://github.com/llvm/llvm-project/commit/d462bf687548a5630f60a8afaa66120df8319e88
DIFF: https://github.com/llvm/llvm-project/commit/d462bf687548a5630f60a8afaa66120df8319e88.diff

LOG: [mlir][Transform] Extend transform.foreach to take multiple arguments (#93705)

Changes transform.foreach's interface to take multiple arguments, e.g.

    transform.foreach %ops1, %ops2, %params : ... {
    ^bb0(%op1, %op2, %param):
      BODY
    }

The semantics are that the payloads of these handles are iterated over as
if they had been zipped together: BODY is executed once for each such
tuple. The documentation explains that this implementation requires the
payloads to have the same length.
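
As a rough illustration of the zipped iteration described above, here is a
standalone C++ sketch. It is not the MLIR implementation: `Element`,
`Payload`, and `zipPayloads` are hypothetical names, payload elements are
modelled as plain strings, and a size mismatch is modelled by returning
`std::nullopt` rather than by emitting a silenceable error.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for one payload element (an op, value, or param).
using Element = std::string;
// The payload associated with one target handle.
using Payload = std::vector<Element>;

// Returns the tuples the body would be executed on, one per iteration, or
// std::nullopt when the payloads do not all have the same size.
std::optional<std::vector<std::vector<Element>>>
zipPayloads(const std::vector<Payload> &payloads) {
  // The first payload fixes the iteration count; no targets means no
  // iterations.
  const std::size_t numIterations =
      payloads.empty() ? 0 : payloads.front().size();

  // "Zipping" only makes sense when every payload has that same size.
  for (const Payload &payload : payloads)
    if (payload.size() != numIterations)
      return std::nullopt;

  // Co-index the payloads: tuple i holds element i of every payload.
  std::vector<std::vector<Element>> tuples;
  for (std::size_t iterIdx = 0; iterIdx < numIterations; ++iterIdx) {
    std::vector<Element> tuple;
    for (const Payload &payload : payloads)
      tuple.push_back(payload[iterIdx]);
    // Each body argument receives exactly one element per iteration.
    assert(tuple.size() == payloads.size());
    tuples.push_back(std::move(tuple));
  }
  return tuples;
}
```

In this model, two payloads of two elements each yield two tuples, while
pairing a two-element payload with a one-element payload yields no tuples
at all, which corresponds to the failure case the log mentions.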

This change also enables the target argument(s) to be any op/value/param
handle.

The added test cases demonstrate some use cases for this change.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
    mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
    mlir/test/Dialect/Transform/ops-invalid.mlir
    mlir/test/Dialect/Transform/ops.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 77048a28d7510..3bb297cbf91d2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -614,43 +614,48 @@ def ForeachOp : TransformDialectOp<"foreach",
          "getSuccessorRegions", "getEntrySuccessorOperands"]>,
      SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
     ]> {
-  let summary = "Executes the body for each payload op";
+  let summary = "Executes the body for each element of the payload";
   let description = [{
-    This op has exactly one region with exactly one block ("body"). The body is
-    executed for each payload op that is associated to the target operand in an
-    unbatched fashion. I.e., the block argument ("iteration variable") is always
-    mapped to exactly one payload op.
-
-    This op always reads the target handle. Furthermore, it consumes the handle
-    if there is a transform op in the body that consumes the iteration variable.
-    This op does not return anything.
-
-    The transformations inside the body are applied in order of their
-    appearance. During application, if any transformation in the sequence fails,
-    the entire sequence fails immediately leaving the payload IR in potentially
-    invalid state, i.e., this operation offers no transformation rollback
-    capabilities.
-
-    This op generates as many handles as the terminating YieldOp has operands.
-    For each result, the payload ops of the corresponding YieldOp operand are
-    merged and mapped to the same resulting handle.
+    Execute the op's body - its single region block - exactly once per
+    element of the payload associated to a target handle. The body's
+    transformations are applied in order of appearance until reaching the
+    (implicit) YieldOp terminator.
+
+    Each iteration gets executed by co-indexing the payloads of the arguments
+    and mapping the body's arguments to these tuples, as though iterating over
+    the zipped together `targets`. As such, in each iteration, the size of the
+    payload of each of the body's block arguments is exactly one.
+
+    This op always reads the target handles. Furthermore, it consumes a handle
+    if there is a transform op in the body that consumes the corresponding
+    block argument. Handles can point to ops, values, or parameters.
+
+    #### Return Modes
+
+    This op produces as many result handles as the body's terminating YieldOp
+    has operands. For each result, the payloads of the corresponding YieldOp
+    operand are merged and mapped to the same resulting handle.
+
+    If the target handles do not associate payloads of the same size, a
+    silencable failure will be generated.
+
+    During application, if any transformation in the sequence fails, the entire
+    sequence fails immediately with the same failure, leaving the payload IR in
+    a potentially invalid state, i.e., this operation offers no transformation
+    rollback capabilities.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+  let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat =
-    "$target `:` type($target) (`->` type($results)^)? $body attr-dict";
+    "$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Allow the dialect prefix to be omitted.
     static StringRef getDefaultDialect() { return "transform"; }
 
-    BlockArgument getIterationVariable() {
-      return getBody().front().getArgument(0);
-    }
-
     transform::YieldOp getYieldOp();
   }];
 }

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 247759e21efb1..1a7ec030f0eb1 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1391,46 +1391,83 @@ DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
-  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-  // Store payload ops in a vector because ops may be removed from the mapping
-  // by the TrackingRewriter while the iteration is in progress.
-  SmallVector<Operation *> targets =
-      llvm::to_vector(state.getPayloadOps(getTarget()));
-  for (Operation *op : targets) {
+  // We store the payloads before executing the body as ops may be removed from
+  // the mapping by the TrackingRewriter while iteration is in progress.
+  SmallVector<SmallVector<MappedValue>> payloads;
+  detail::prepareValueMappings(payloads, getTargets(), state);
+  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
+
+  // As we will be "zipping" over them, check all payloads have the same size.
+  for (size_t argIdx = 1; argIdx < payloads.size(); argIdx++) {
+    if (payloads[argIdx].size() != numIterations) {
+      return emitSilenceableError()
+             << "prior targets' payload size (" << numIterations
+             << ") differs from payload size (" << payloads[argIdx].size()
+             << ") of target " << getTargets()[argIdx];
+    }
+  }
+
+  // Start iterating, indexing into payloads to obtain the right arguments to
+  // call the body with - each slice of payloads at the same argument index
+  // corresponding to a tuple to use as the body's block arguments.
+  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
+  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
+  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
     auto scope = state.make_region_scope(getBody());
-    if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
-      return DiagnosedSilenceableFailure::definiteFailure();
+    // Set up arguments to the region's block.
+    for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
+      MappedValue argument = payloads[argIdx][iterIdx];
+      // Note that each blockArg's handle gets associated with just a single
+      // element from the corresponding target's payload.
+      if (failed(state.mapBlockArgument(blockArg, {argument})))
+        return DiagnosedSilenceableFailure::definiteFailure();
+    }
 
     // Execute loop body.
     for (Operation &transform : getBody().front().without_terminator()) {
       DiagnosedSilenceableFailure result = state.applyTransform(
-          cast<transform::TransformOpInterface>(transform));
+          llvm::cast<transform::TransformOpInterface>(transform));
       if (!result.succeeded())
         return result;
     }
 
-    // Append yielded payload ops to result list (if any).
-    for (unsigned i = 0; i < getNumResults(); ++i) {
-      auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
-      resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
-    }
-  }
-
-  for (unsigned i = 0; i < getNumResults(); ++i)
-    results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
+    // Append yielded payloads to corresponding results from prior iterations.
+    OperandRange yieldOperands = getYieldOp().getOperands();
+    for (auto &&[result, yieldOperand, resTuple] :
+         llvm::zip_equal(getResults(), yieldOperands, zippedResults))
+      // NB: each iteration we add any number of ops/vals/params to a result.
+      if (isa<TransformHandleTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
+      else if (isa<TransformValueHandleTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
+      else if (isa<TransformParamTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getParams(yieldOperand));
+      else
+        assert(false && "unhandled handle type");
+  }
+
+  // Associate the accumulated result payloads to the op's actual results.
+  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
+    results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
 
   return DiagnosedSilenceableFailure::success();
 }
 
 void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  BlockArgument iterVar = getIterationVariable();
-  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
-        return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
-      })) {
-    consumesHandle(getTarget(), effects);
-  } else {
-    onlyReadsHandle(getTarget(), effects);
+  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
+  // arity errors, this method might get called before/in absence of `verify()`.
+  for (auto &&[target, blockArg] :
+       llvm::zip(getTargets(), getBody().front().getArguments())) {
+    BlockArgument blockArgument = blockArg;
+    if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+          return isHandleConsumed(blockArgument,
+                                  cast<TransformOpInterface>(&op));
+        })) {
+      consumesHandle(target, effects);
+    } else {
+      onlyReadsHandle(target, effects);
+    }
   }
 
   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
@@ -1463,8 +1500,8 @@ void transform::ForeachOp::getSuccessorRegions(
 
 OperandRange
 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  // The iteration variable op handle is mapped to a subset (one op to be
-  // precise) of the payload ops of the ForeachOp operand.
+  // Each block argument handle is mapped to a subset (one op to be precise)
+  // of the payload of the corresponding `targets` operand of ForeachOp.
   assert(point == getBody() && "unexpected region index");
   return getOperation()->getOperands();
 }
@@ -1474,14 +1511,27 @@ transform::YieldOp transform::ForeachOp::getYieldOp() {
 }
 
 LogicalResult transform::ForeachOp::verify() {
-  auto yieldOp = getYieldOp();
-  if (getNumResults() != yieldOp.getNumOperands())
-    return emitOpError() << "expects the same number of results as the "
-                            "terminator has operands";
-  for (Value v : yieldOp.getOperands())
-    if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
-      return yieldOp->emitOpError("expects operands to have types implementing "
-                                  "TransformHandleTypeInterface");
+  for (auto [targetOpt, bodyArgOpt] :
+       llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
+    if (!targetOpt || !bodyArgOpt)
+      return emitOpError() << "expects the same number of targets as the body "
+                              "has block arguments";
+    if (targetOpt.value().getType() != bodyArgOpt.value().getType())
+      return emitOpError(
+          "expects co-indexed targets and the body's "
+          "block arguments to have the same op/value/param type");
+  }
+
+  for (auto [resultOpt, yieldOperandOpt] :
+       llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
+    if (!resultOpt || !yieldOperandOpt)
+      return emitOpError() << "expects the same number of results as the "
+                              "yield terminator has operands";
+    if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
+      return emitOpError("expects co-indexed results and yield "
+                         "operands to have the same op/value/param type");
+  }
+
   return success();
 }
 

diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
index 15b24b56608e3..51332ffce03d1 100644
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -6,15 +6,17 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
-    %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
     %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
     %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
-    %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op
-    %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.any_op
-    transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.foreach %5 : !transform.any_op {
+    ^bb0(%inner_linalg: !transform.any_op):
+      %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
+      %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
+      transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    }
     transform.yield
   }
 }
@@ -114,9 +116,12 @@ module attributes {transform.with_named_sequence} {
     %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
     %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
     %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
-    %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.param<i64>
-    transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
-    transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+    transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
+    ^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
+      %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
+      transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+      transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+    }
     transform.yield
   }
 }

diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 0f51b1cdbe0cf..54dd2bdf953ca 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -328,3 +328,76 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+// -----
+
+// CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @foreach_loop_pair_fuse(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  } {target_loops}
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  } {source_loops}
+  %2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>)  {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %5 = arith.addf %3, %2 : vector<32xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  } {target_loops}
+  %dup2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<32xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<32xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<32xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  } {source_loops}
+  return %1, %dup1, %2, %dup2 : tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>
+}
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %target_loops = transform.structured.match ops{["scf.for"]} attributes {target_loops} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %source_loops = transform.structured.match ops{["scf.for"]} attributes {source_loops} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.foreach %target_loops, %source_loops : !transform.any_op, !transform.any_op {
+    ^bb0(%target_loop: !transform.any_op, %source_loop: !transform.any_op):
+      %fused = transform.loop.fuse_sibling %target_loop into %source_loop : (!transform.any_op,!transform.any_op) ->  !transform.any_op
+    }
+    transform.yield
+  }
+}

diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 30a68cc5f3c44..71a260f1196e9 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -279,6 +279,55 @@ transform.sequence failures(propagate) {
 
 // -----
 
+transform.sequence failures(propagate) {
+  ^bb0(%root: !transform.any_op):
+  %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  // expected-error @below {{op expects the same number of targets as the body has block arguments}}
+  transform.foreach %op : !transform.any_op -> !transform.any_op, !transform.any_value {
+  ^bb1(%op_arg: !transform.any_op, %val_arg: !transform.any_value):
+    transform.yield %op_arg, %val_arg : !transform.any_op, !transform.any_value
+  }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+  ^bb0(%root: !transform.any_op):
+  %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  // expected-error @below {{op expects co-indexed targets and the body's block arguments to have the same op/value/param type}}
+  transform.foreach %op : !transform.any_op -> !transform.any_value {
+  ^bb1(%val_arg: !transform.any_value):
+    transform.yield %val_arg : !transform.any_value
+  }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+  ^bb0(%root: !transform.any_op):
+  %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  // expected-error @below {{op expects the same number of results as the yield terminator has operands}}
+  transform.foreach %op : !transform.any_op -> !transform.any_op, !transform.any_op {
+  ^bb1(%arg_op: !transform.any_op):
+    transform.yield %arg_op : !transform.any_op
+  }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+  ^bb0(%root: !transform.any_op):
+  %op = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  %val = transform.test_produce_value_handle_to_self_operand %op : (!transform.any_op) -> !transform.any_value
+  // expected-error @below {{expects co-indexed results and yield operands to have the same op/value/param type}}
+  transform.foreach %op, %val : !transform.any_op, !transform.any_value -> !transform.any_op, !transform.any_value {
+  ^bb1(%op_arg: !transform.any_op, %val_arg: !transform.any_value):
+    transform.yield %val_arg, %op_arg : !transform.any_value, !transform.any_op
+  }
+}
+
+// -----
+
 transform.sequence failures(suppress) {
 ^bb0(%arg0: !transform.any_op):
   // expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}}

diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index b03a9f4d760d2..e9baffde262fa 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -68,11 +68,25 @@ transform.sequence failures(propagate) {
 }
 
 // CHECK: transform.sequence
-// CHECK: foreach
 transform.sequence failures(propagate) {
-^bb0(%arg0: !transform.any_op):
-  transform.foreach %arg0 : !transform.any_op {
-  ^bb1(%arg1: !transform.any_op):
+^bb0(%op0: !transform.any_op, %val0: !transform.any_value, %par0: !transform.any_param):
+  // CHECK: foreach %{{.*}} : !transform.any_op
+  transform.foreach %op0 : !transform.any_op {
+  ^bb1(%op1: !transform.any_op):
+  }
+  // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param
+  transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param {
+  ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+  }
+  // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op
+  transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_op {
+  ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+    transform.yield %op1 : !transform.any_op
+  }
+  // CHECK: foreach %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value
+  transform.foreach %op0, %val0, %par0 : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_param, !transform.any_value {
+  ^bb1(%op1: !transform.any_op, %val1: !transform.any_value, %par1: !transform.any_param):
+    transform.yield %par1, %val1 : !transform.any_param, !transform.any_value
   }
 }
 

diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b6850e2024d53..4fe2dbedff56e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -830,6 +830,91 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %results, %types = transform.foreach %0 : !transform.any_op -> !transform.any_value, !transform.any_param {
+    ^bb0(%op0 : !transform.any_op):
+      %result = transform.get_result %op0[0] : (!transform.any_op) -> !transform.any_value
+      %type = transform.get_type elemental %result  : (!transform.any_value) -> !transform.any_param
+      transform.yield %result, %type : !transform.any_value, !transform.any_param
+    }
+    transform.debug.emit_remark_at %results, "result selected" : !transform.any_value
+    transform.debug.emit_param_as_remark %types, "elemental types" at %0 : !transform.any_param, !transform.any_op
+
+    transform.yield
+  }
+}
+
+func.func @payload(%lhs: tensor<10x20xf16>,
+                   %rhs: tensor<20x15xf32>) -> (tensor<10x15xf64>, tensor<10x15xf32>) {
+  %cst64 = arith.constant 0.0 : f64
+  %empty64 = tensor.empty() : tensor<10x15xf64>
+  %fill64 = linalg.fill ins(%cst64 : f64) outs(%empty64 : tensor<10x15xf64>) -> tensor<10x15xf64>
+  // expected-remark @below {{result selected}}
+  // expected-note @below {{value handle points to an op result #0}}
+  // expected-remark @below {{elemental types f64, f32}}
+  %result64 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>)
+                         outs(%fill64: tensor<10x15xf64>) -> tensor<10x15xf64>
+
+  %cst32 = arith.constant 0.0 : f32
+  %empty32 = tensor.empty() : tensor<10x15xf32>
+  %fill32 = linalg.fill ins(%cst32 : f32) outs(%empty32 : tensor<10x15xf32>) -> tensor<10x15xf32>
+  // expected-remark @below {{result selected}}
+  // expected-note @below {{value handle points to an op result #0}}
+  // expected-remark @below {{elemental types f64, f32}}
+  %result32 = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>)
+                           outs(%fill32: tensor<10x15xf32>) -> tensor<10x15xf32>
+
+  return %result64, %result32 : tensor<10x15xf64>, tensor<10x15xf32>
+
+}
+
+// -----
+
+func.func @two_const_ops() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 1 : index
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %two_ops = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %one_param = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+    // expected-error @below {{prior targets' payload size (2) differs from payload size (1) of target}}
+    transform.foreach %two_ops, %one_param : !transform.any_op, !transform.test_dialect_param {
+    ^bb2(%op: !transform.any_op, %param: !transform.test_dialect_param):
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @one_const_op() {
+  %0 = arith.constant 0 : index
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %one_op = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %one_val = transform.test_produce_value_handle_to_self_operand %one_op : (!transform.any_op) -> !transform.any_value
+    %param_one = transform.param.constant 1 : i32 -> !transform.test_dialect_param
+    %param_two = transform.param.constant 2 : i32 -> !transform.test_dialect_param
+    %two_params = transform.merge_handles %param_one, %param_two : !transform.test_dialect_param
+
+    // expected-error @below {{prior targets' payload size (1) differs from payload size (2) of target}}
+    transform.foreach %one_val, %one_op, %two_params : !transform.any_value, !transform.any_op, !transform.test_dialect_param {
+    ^bb2(%val: !transform.any_value, %op: !transform.any_op, %param: !transform.test_dialect_param):
+    }
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @consume_in_foreach()
 //  CHECK-NEXT:   return
 func.func @consume_in_foreach() {


        

