[Mlir-commits] [mlir] [mlir] Makes `zip_shortest` an optional keyword in `transform.foreach` (PR #98492)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 11 07:53:48 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Guillermo Callaghan (Guillermo-Callaghan)

<details>
<summary>Changes</summary>

This PR addresses a [comment] made by @ftynse about the syntax for `ForeachOp`. The syntax was modified by @muneebkhan85 in #82792, where the attribute dictionary was moved to the middle of the assembly format.
This patch moves it back to its original place at the end and introduces an optional `with_zip_shortest` keyword, which replaces the `zip_shortest` unit attribute.

[comment]: https://github.com/llvm/llvm-project/pull/82792#pullrequestreview-2132814144

---
Full diff: https://github.com/llvm/llvm-project/pull/98492.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+3-2) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+3-3) 
- (modified) mlir/test/Dialect/Linalg/continuous-tiling-full.mlir (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 7a661c663e010..b946fc8875860 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -649,11 +649,12 @@ def ForeachOp : TransformDialectOp<"foreach",
   }];
 
   let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
-                       UnitAttr:$zip_shortest);
+                       UnitAttr:$with_zip_shortest);
   let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat =
-    "$targets attr-dict `:` type($targets) (`->` type($results)^)? $body";
+    "$targets oilist(`with_zip_shortest` $with_zip_shortest) `:` "
+    "type($targets) (`->` type($results)^)? $body attr-dict";
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3d6a9a18d9f41..c4238080533be 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1396,11 +1396,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
   SmallVector<SmallVector<MappedValue>> payloads;
   detail::prepareValueMappings(payloads, getTargets(), state);
   size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
-  bool isZipShortest = getZipShortest();
+  bool withZipShortest = getWithZipShortest();
 
   // In case of `zip_shortest`, set the number of iterations to the
   // smallest payload in the targets.
-  if (isZipShortest) {
+  if (withZipShortest) {
     numIterations =
         llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
                                         const SmallVector<MappedValue> &B) {
@@ -1414,7 +1414,7 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
   // As we will be "zipping" over them, check all payloads have the same size.
   // `zip_shortest` adjusts all payloads to the same size, so skip this check
   // when true.
-  for (size_t argIdx = 1; !isZipShortest && argIdx < payloads.size();
+  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
        argIdx++) {
     if (payloads[argIdx].size() != numIterations) {
       return emitSilenceableError()
diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
index b61c727f9cffc..7410ff593d01a 100644
--- a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
@@ -127,7 +127,7 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
     %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
-    transform.foreach %linalg_splits, %tile_sizes {zip_shortest} : !transform.any_op, !transform.any_op {
+    transform.foreach %linalg_splits, %tile_sizes with_zip_shortest : !transform.any_op, !transform.any_op {
     ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
       %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.yield

``````````

</details>


https://github.com/llvm/llvm-project/pull/98492


More information about the Mlir-commits mailing list