[Mlir-commits] [mlir] [mlir][transform] Fix failure in flattening already flattened linalg ops (PR #86037)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 20 19:53:06 PDT 2024
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/86037
>From 127b99832b0051954ee0cd0247a0735e74126df7 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 20 Mar 2024 18:30:30 -0500
Subject: [PATCH] Fix failure in flattening already flattened linalg ops
---
.../TransformOps/LinalgTransformOps.cpp | 15 ++++++++-----
.../Dialect/Linalg/flatten-elementwise.mlir | 21 +++++++++++++++++++
2 files changed, 31 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..c93b656f42353c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- if (target.getNumLoops() <= 1)
+ if (!isElementwise(target)) {
+ failed(rewriter.notifyMatchFailure(
+ target, "only elementwise flattening is supported"));
+ return emitDefaultSilenceableFailure(target);
+ }
+ // If rank <= 1, do nothing
+ if (target.getNumLoops() <= 1) {
+ results.push_back(target);
return DiagnosedSilenceableFailure::success();
+ }
ReassociationIndices reassociation(target.getNumLoops());
std::iota(reassociation.begin(), reassociation.end(), 0);
auto maybeFlattened =
- (isElementwise(target))
- ? collapseOpIterationDims(target, reassociation, rewriter)
- : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
- target, "only elementwise flattening is supported"));
+ collapseOpIterationDims(target, reassociation, rewriter);
if (failed(maybeFlattened))
return emitDefaultSilenceableFailure(target);
results.push_back(maybeFlattened->collapsedOp);
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 858c133dd536ca..5a27fe76b13411 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @map_already_flat(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-NEXT: linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
+func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+ linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @generic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
More information about the Mlir-commits mailing list