[Mlir-commits] [mlir] [mlir] Makes `zip_shortest` an optional keyword in `transform.foreach` (PR #98492)

Guillermo Callaghan llvmlistbot at llvm.org
Thu Jul 11 07:52:58 PDT 2024


https://github.com/Guillermo-Callaghan created https://github.com/llvm/llvm-project/pull/98492

This PR addresses a [comment] made by @ftynse about the syntax for `ForeachOp`. The syntax was modified by @muneebkhan85 in #82792, where the attribute dictionary was moved to the middle.
This patch moves it back to its original place at the end and introduces an optional keyword for `zip_shortest`.

[comment]: https://github.com/llvm/llvm-project/pull/82792#pullrequestreview-2132814144

>From 995ee5170e527bbc044cd4e671f142c2444e3ac4 Mon Sep 17 00:00:00 2001
From: Guillermo Callaghan <guillermo.callaghan at huawei.com>
Date: Thu, 11 Jul 2024 22:17:44 +0800
Subject: [PATCH] [mlir] Make zip_shortest an optional keyword

---
 mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 5 +++--
 mlir/lib/Dialect/Transform/IR/TransformOps.cpp         | 6 +++---
 mlir/test/Dialect/Linalg/continuous-tiling-full.mlir   | 2 +-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 7a661c663e010..b946fc8875860 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -649,11 +649,12 @@ def ForeachOp : TransformDialectOp<"foreach",
   }];
 
   let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
-                       UnitAttr:$zip_shortest);
+                       UnitAttr:$with_zip_shortest);
   let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat =
-    "$targets attr-dict `:` type($targets) (`->` type($results)^)? $body";
+    "$targets oilist(`with_zip_shortest` $with_zip_shortest) `:` "
+    "type($targets) (`->` type($results)^)? $body attr-dict";
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3d6a9a18d9f41..c4238080533be 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1396,11 +1396,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
   SmallVector<SmallVector<MappedValue>> payloads;
   detail::prepareValueMappings(payloads, getTargets(), state);
   size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
-  bool isZipShortest = getZipShortest();
+  bool withZipShortest = getWithZipShortest();
 
   // In case of `zip_shortest`, set the number of iterations to the
   // smallest payload in the targets.
-  if (isZipShortest) {
+  if (withZipShortest) {
     numIterations =
         llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
                                         const SmallVector<MappedValue> &B) {
@@ -1414,7 +1414,7 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
   // As we will be "zipping" over them, check all payloads have the same size.
   // `zip_shortest` adjusts all payloads to the same size, so skip this check
   // when true.
-  for (size_t argIdx = 1; !isZipShortest && argIdx < payloads.size();
+  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
        argIdx++) {
     if (payloads[argIdx].size() != numIterations) {
       return emitSilenceableError()
diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
index b61c727f9cffc..7410ff593d01a 100644
--- a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
@@ -127,7 +127,7 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
     %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
-    transform.foreach %linalg_splits, %tile_sizes {zip_shortest} : !transform.any_op, !transform.any_op {
+    transform.foreach %linalg_splits, %tile_sizes with_zip_shortest : !transform.any_op, !transform.any_op {
     ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
       %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.yield



More information about the Mlir-commits mailing list