[Mlir-commits] [mlir] [mlir][Transform] Extend transform.foreach to take multiple arguments (PR #93705)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 29 09:57:20 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Rolf Morel (rolfmorel)

<details>
<summary>Changes</summary>

Changes `transform.foreach`'s interface to take multiple arguments, e.g. `transform.foreach %ops1, %ops2, %params : ... { ^bb0(%op1, %op2, %param): BODY }`. The payloads of these handles are iterated over as if they had been zipped together: BODY is executed once for each resulting tuple. As the documentation explains, this implementation requires the payloads to have the same length.

This change also enables the target argument(s) to be any op/value/param handle.

The added test cases demonstrate some use cases for this change.
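
The zipped iteration described in this summary can be modelled outside of MLIR. The following is a minimal standalone sketch and not the code from the patch: `Payload`, `zipPayloads`, and the use of strings as stand-ins for payload elements are assumptions made purely for illustration. It mirrors the stated requirement that all payloads have the same length by returning `std::nullopt` on a mismatch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the payload associated with one handle.
using Payload = std::vector<std::string>;

// Zip the payloads into per-iteration tuples, one tuple per body execution.
// Returns std::nullopt when payload sizes differ (models the failure case).
std::optional<std::vector<std::vector<std::string>>>
zipPayloads(const std::vector<Payload> &targets) {
  std::vector<std::vector<std::string>> tuples;
  if (targets.empty())
    return tuples;

  // The first target fixes the number of iterations.
  for (const std::string &elem : targets.front())
    tuples.push_back({elem});

  // Every further target must contribute exactly one element per tuple.
  for (std::size_t t = 1; t < targets.size(); ++t) {
    if (targets[t].size() != tuples.size())
      return std::nullopt;
    for (std::size_t i = 0; i < tuples.size(); ++i)
      tuples[i].push_back(targets[t][i]);
  }

  // Each tuple now holds one element from every target.
  assert(tuples.empty() || tuples.front().size() == targets.size());
  return tuples;
}
```

In this model, each inner vector corresponds to the block arguments of one execution of BODY, with exactly one payload element per argument.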

---

Patch is 28.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93705.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+31-26) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+112-29) 
- (modified) mlir/test/Dialect/Linalg/multisize-tiling-full.mlir (+13-8) 
- (modified) mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir (+73) 
- (modified) mlir/test/Dialect/Transform/ops.mlir (+18-4) 
- (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+85) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 77048a28d7510..e61cd77339ac6 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -614,43 +614,48 @@ def ForeachOp : TransformDialectOp<"foreach",
          "getSuccessorRegions", "getEntrySuccessorOperands"]>,
      SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
     ]> {
-  let summary = "Executes the body for each payload op";
+  let summary = "Executes the body for each element of the payload";
   let description = [{
-    This op has exactly one region with exactly one block ("body"). The body is
-    executed for each payload op that is associated to the target operand in an
-    unbatched fashion. I.e., the block argument ("iteration variable") is always
-    mapped to exactly one payload op.
-
-    This op always reads the target handle. Furthermore, it consumes the handle
-    if there is a transform op in the body that consumes the iteration variable.
-    This op does not return anything.
-
-    The transformations inside the body are applied in order of their
-    appearance. During application, if any transformation in the sequence fails,
-    the entire sequence fails immediately leaving the payload IR in potentially
-    invalid state, i.e., this operation offers no transformation rollback
-    capabilities.
-
-    This op generates as many handles as the terminating YieldOp has operands.
-    For each result, the payload ops of the corresponding YieldOp operand are
-    merged and mapped to the same resulting handle.
+    Execute the op's body - its single region block - exactly once per
+    element of the payload associated to a target handle. The body's
+    transformations are applied in order of appearance until reaching the
+    (implicit) YieldOp terminator.
+
+    Each iteration gets executed by co-indexing the payloads of the arguments
+    and mapping the body's arguments to these tuples, as though iterating over
+    the zipped together `targets`. As such, in each iteration, the size of the
+    payload of each of the body's block arguments is exactly one.
+
+    This op always reads the target handles. Furthermore, it consumes a handle
+    if there is a transform op in the body that consumes the corresponding
+    block argument. Handles can point to ops, values, or parameters.
+
+    #### Return Modes
+
+    This op produces as many result handles as the body's terminating YieldOp
+    has operands. For each result, the payloads of the corresponding YieldOp
+    operand are merged and mapped to the same resulting handle.
+
+    If the target handles do not associate payloads of the same size, or they
+    do not associate any payload at all, a silencable failure will be generated.
+
+    During application, if any transformation in the sequence fails, the entire
+    sequence fails immediately with the same failure, leaving the payload IR in
+    a potentially invalid state, i.e., this operation offers no transformation
+    rollback capabilities.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+  let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat =
-    "$target `:` type($target) (`->` type($results)^)? $body attr-dict";
+    "$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Allow the dialect prefix to be omitted.
     static StringRef getDefaultDialect() { return "transform"; }
 
-    BlockArgument getIterationVariable() {
-      return getBody().front().getArgument(0);
-    }
-
     transform::YieldOp getYieldOp();
   }];
 }
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 247759e21efb1..b58f4e0672fc2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1391,15 +1391,62 @@ DiagnosedSilenceableFailure
 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
-  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
-  // Store payload ops in a vector because ops may be removed from the mapping
-  // by the TrackingRewriter while the iteration is in progress.
-  SmallVector<Operation *> targets =
-      llvm::to_vector(state.getPayloadOps(getTarget()));
-  for (Operation *op : targets) {
+  // Collect the arguments with which to call each iteration of the body.
+  // We store the payload before executing the body as ops may be removed from
+  // the mapping by the TrackingRewriter while the iteration is in progress.
+  SmallVector<SmallVector<MappedValue>> zippedArgs;
+  for (auto firstTarget : getTargets().take_front(1)) // Loop runs at most once.
+    // For each element, init a tuple with which to call the body later on.
+    if (isa<TransformHandleTypeInterface>(firstTarget.getType()))
+      for (auto &op : state.getPayloadOps(firstTarget))
+        zippedArgs.append({{op}}); // NB: append's argument is an init-list.
+    else if (isa<TransformValueHandleTypeInterface>(firstTarget.getType()))
+      for (auto val : state.getPayloadValues(firstTarget))
+        zippedArgs.append({{val}});
+    else if (isa<TransformParamTypeInterface>(firstTarget.getType()))
+      for (auto param : state.getParams(firstTarget))
+        zippedArgs.append({{param}});
+    else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << firstTarget.getType();
+
+  for (auto target : getTargets().drop_front(1)) {
+    // Append each element of payload to the co-indexed body-arguments-as-tuple.
+    size_t payloadSize = 0;
+    if (isa<TransformHandleTypeInterface>(target.getType())) {
+      for (auto op : state.getPayloadOps(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({op});
+    } else if (isa<TransformValueHandleTypeInterface>(target.getType())) {
+      for (auto val : state.getPayloadValues(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({val});
+    } else if (isa<TransformParamTypeInterface>(target.getType())) {
+      for (auto param : state.getParams(target))
+        if (++payloadSize <= zippedArgs.size())
+          zippedArgs[payloadSize - 1].append({param});
+    } else
+      return emitDefiniteFailure()
+             << "unhandled handle type " << target.getType();
+
+    if (payloadSize != zippedArgs.size())
+      return emitSilenceableError()
+             << "payload size of prior targets (" << zippedArgs.size()
+             << ") differs from payload size (" << payloadSize << ") of target "
+             << target;
+  }
+
+  // For each arguments-as-tuple collected up above, execute the body region.
+  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
+  for (SmallVector<MappedValue> &argsTuple : zippedArgs) {
     auto scope = state.make_region_scope(getBody());
-    if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
-      return DiagnosedSilenceableFailure::definiteFailure();
+    // Set up arguments to the region's block.
+    for (auto &&[blockArg, argument] :
+         llvm::zip_equal(getBody().front().getArguments(), argsTuple))
+      // Note: each blockArg's handle gets associated with just a single element
+      // from the corresponding target's payload.
+      if (failed(state.mapBlockArgument(blockArg, {argument})))
+        return DiagnosedSilenceableFailure::definiteFailure();
 
     // Execute loop body.
     for (Operation &transform : getBody().front().without_terminator()) {
@@ -1409,28 +1456,44 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
         return result;
     }
 
-    // Append yielded payload ops to result list (if any).
-    for (unsigned i = 0; i < getNumResults(); ++i) {
-      auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
-      resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
-    }
+    // Append yielded payloads to results.
+    auto yieldOperands = getYieldOp().getOperands();
+    for (auto &&[result, yieldOperand, resTuple] :
+         llvm::zip_equal(getResults(), yieldOperands, zippedResults))
+      // NB: each iteration we add any number of ops/vals/params to an opresult.
+      if (isa<TransformHandleTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
+      else if (isa<TransformValueHandleTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
+      else if (isa<TransformParamTypeInterface>(result.getType()))
+        llvm::append_range(resTuple, state.getParams(yieldOperand));
+      else
+        return emitDefiniteFailure()
+               << "unhandled handle type " << result.getType();
   }
 
   for (unsigned i = 0; i < getNumResults(); ++i)
-    results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
+    results.setMappedValues(cast<OpResult>(getResult(i)), zippedResults[i]);
 
   return DiagnosedSilenceableFailure::success();
 }
 
 void transform::ForeachOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  BlockArgument iterVar = getIterationVariable();
-  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
-        return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
-      })) {
-    consumesHandle(getTarget(), effects);
-  } else {
-    onlyReadsHandle(getTarget(), effects);
+
+  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
+  // arity errors, this method might get called before/in absence of `verify()`.
+  for (auto &&[target, blockArg] :
+       llvm::zip(getTargets(), getBody().front().getArguments())) {
+    BlockArgument blockArgument = blockArg;
+    if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
+          return isHandleConsumed(blockArgument,
+                                  cast<TransformOpInterface>(&op));
+        })) {
+      consumesHandle(target, effects);
+    } else {
+      onlyReadsHandle(target, effects);
+    }
   }
 
   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
@@ -1463,6 +1526,7 @@ void transform::ForeachOp::getSuccessorRegions(
 
 OperandRange
 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  // TODO: figure out how to update the comment & impl if necessary
   // The iteration variable op handle is mapped to a subset (one op to be
   // precise) of the payload ops of the ForeachOp operand.
   assert(point == getBody() && "unexpected region index");
@@ -1474,14 +1538,33 @@ transform::YieldOp transform::ForeachOp::getYieldOp() {
 }
 
 LogicalResult transform::ForeachOp::verify() {
-  auto yieldOp = getYieldOp();
-  if (getNumResults() != yieldOp.getNumOperands())
-    return emitOpError() << "expects the same number of results as the "
-                            "terminator has operands";
-  for (Value v : yieldOp.getOperands())
-    if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
-      return yieldOp->emitOpError("expects operands to have types implementing "
-                                  "TransformHandleTypeInterface");
+  for (auto [targetOpt, bodyArgOpt] :
+       llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
+    if (!targetOpt || !bodyArgOpt)
+      return emitOpError() << "expects the same number of targets as the body "
+                              "has block arguments";
+    auto target = targetOpt.value();
+    if (target.getType() != bodyArgOpt.value().getType() ||
+        !isa<TransformHandleTypeInterface, TransformValueHandleTypeInterface,
+             TransformParamTypeInterface>(target.getType()))
+      return emitOpError(
+          "expects co-indexed targets and the body's "
+          "block arguments to have the same op/value/param type");
+  }
+
+  for (auto [resultOpt, yieldOperandOpt] :
+       llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
+    if (!resultOpt || !yieldOperandOpt)
+      return emitOpError() << "expects the same number of results as the "
+                              "yield terminator has operands";
+    auto result = resultOpt.value();
+    if (result.getType() != yieldOperandOpt.value().getType() ||
+        !isa<TransformHandleTypeInterface, TransformValueHandleTypeInterface,
+             TransformParamTypeInterface>(result.getType()))
+      return emitOpError("expects co-indexed results and yield "
+                         "operands to have the same op/value/param type");
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
index 15b24b56608e3..51332ffce03d1 100644
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -6,15 +6,17 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
-    %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
     %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
     %3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
     %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
-    %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op
-    %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.any_op
-    transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.foreach %5 : !transform.any_op {
+    ^bb0(%inner_linalg: !transform.any_op):
+      %low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
+      %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
+      transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    }
     transform.yield
   }
 }
@@ -114,9 +116,12 @@ module attributes {transform.with_named_sequence} {
     %4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
     %5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
     %tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
-    %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !transform.any_op, !transform.param<i64>
-    transform.structured.tile_using_for %6#0 tile_sizes [0, %tt#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
-    transform.structured.tile_using_for %6#1 tile_sizes [0, %tt#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+    transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
+    ^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
+      %inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
+      transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+      transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+    }
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
index 0f51b1cdbe0cf..54dd2bdf953ca 100644
--- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
+++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
@@ -328,3 +328,76 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+// -----
+
+// CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}
+func.func @foreach_loop_pair_fuse(%arg1: tensor<128xf32>, %arg2: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+  // CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+  // CHECK-DAG: [[C128:%.*]] = arith.constant 128 : index
+  // CHECK-DAG: [[ZERO:%.*]] = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: [[RST:%.*]]:2 = scf.for [[IV:%.*]] = [[C0]] to [[C128]] step [[C16]] iter_args([[IB0:%.*]] = [[B]], [[IB1:%.*]] = [[B]]) {{.*}}
+  %1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_read [[A]][[[IV]]], [[ZERO]]
+  // CHECK-DAG:   [[SLICE0:%.*]] = vector.transfer_read [[IB0]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT1:%.*]] = arith.addf [[SLICE0]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT0:%.*]] = vector.transfer_write [[OUT1]], [[IB0]][[[IV]]]
+    %2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %5 = arith.addf %3, %2 : vector<16xf32>
+    %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+    scf.yield %6 : tensor<128xf32>
+  } {target_loops}
+  %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %arg2) -> (tensor<128xf32>) {
+  // CHECK-DAG:   [[SLICE1:%.*]] = vector.transfer_read [[IB1]][[[IV]]], [[ZERO]]
+  // CHECK:       [[OUT2:%.*]] = arith.addf [[SLICE1]], [[ASLICE]]
+  // CHECK-NEXT:  [[WRT1:%.*]] = vector.transfer_write [[OUT2]], [[IB1]][[[IV]]]
+    %dup2 = vector.transfer_read %arg1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
+    %dup5 = arith.addf %dup3, %dup2 : vector<16xf32>
+    %dup6 = vector.transfer_write %dup5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
+  // CHECK: scf.yield [[WRT0]], [[WRT1]] : {{.*}}
+    scf.yield %dup6 : tensor<128xf32>
+  } {source_loops}
+  %2 = scf.for %arg3 = %c0 to %c128 step %c32 iter_args(%arg4 = %arg2) -> (tensor<128xf32>)  {
+  // CHECK-DAG:   [[ASLICE:%.*]] = vector.transfer_re...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/93705

