[Mlir-commits] [mlir] [mlir] Makes `zip_shortest` an optional keyword in `transform.foreach` (PR #98492)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 11 07:53:48 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Guillermo Callaghan (Guillermo-Callaghan)
<details>
<summary>Changes</summary>
This PR addresses a [comment] made by @<!-- -->ftynse about the syntax for `ForeachOp`. The syntax was modified by @<!-- -->muneebkhan85 in #<!-- -->82792, where the attribute dictionary was moved to the middle.
This patch moves it back to its original place at the end and introduces an optional keyword, `with_zip_shortest`, which replaces writing `zip_shortest` in the attribute dictionary.
[comment]: https://github.com/llvm/llvm-project/pull/82792#pullrequestreview-2132814144
---
Full diff: https://github.com/llvm/llvm-project/pull/98492.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+3-2)
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+3-3)
- (modified) mlir/test/Dialect/Linalg/continuous-tiling-full.mlir (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 7a661c663e010..b946fc8875860 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -649,11 +649,12 @@ def ForeachOp : TransformDialectOp<"foreach",
}];
let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
- UnitAttr:$zip_shortest);
+ UnitAttr:$with_zip_shortest);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
- "$targets attr-dict `:` type($targets) (`->` type($results)^)? $body";
+ "$targets oilist(`with_zip_shortest` $with_zip_shortest) `:` "
+ "type($targets) (`->` type($results)^)? $body attr-dict";
let hasVerifier = 1;
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3d6a9a18d9f41..c4238080533be 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1396,11 +1396,11 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
SmallVector<SmallVector<MappedValue>> payloads;
detail::prepareValueMappings(payloads, getTargets(), state);
size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
- bool isZipShortest = getZipShortest();
+ bool withZipShortest = getWithZipShortest();
// In case of `zip_shortest`, set the number of iterations to the
// smallest payload in the targets.
- if (isZipShortest) {
+ if (withZipShortest) {
numIterations =
llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
const SmallVector<MappedValue> &B) {
@@ -1414,7 +1414,7 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
// As we will be "zipping" over them, check all payloads have the same size.
// `zip_shortest` adjusts all payloads to the same size, so skip this check
// when true.
- for (size_t argIdx = 1; !isZipShortest && argIdx < payloads.size();
+ for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
argIdx++) {
if (payloads[argIdx].size() != numIterations) {
return emitSilenceableError()
diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
index b61c727f9cffc..7410ff593d01a 100644
--- a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
@@ -127,7 +127,7 @@ module attributes {transform.with_named_sequence} {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
- transform.foreach %linalg_splits, %tile_sizes {zip_shortest} : !transform.any_op, !transform.any_op {
+ transform.foreach %linalg_splits, %tile_sizes with_zip_shortest : !transform.any_op, !transform.any_op {
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
``````````
</details>
https://github.com/llvm/llvm-project/pull/98492
More information about the Mlir-commits
mailing list