[Mlir-commits] [mlir] [mlir][transform] Fix failure in flattening already flattened linalg ops (PR #86037)
llvmlistbot at llvm.org
Wed Mar 20 16:32:00 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: None (srcarroll)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/86037.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+10-5)
- (modified) mlir/test/Dialect/Linalg/flatten-elementwise.mlir (+21)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..c93b656f42353c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- if (target.getNumLoops() <= 1)
+ if (!isElementwise(target)) {
+ failed(rewriter.notifyMatchFailure(
+ target, "only elementwise flattening is supported"));
+ return emitDefaultSilenceableFailure(target);
+ }
+ // If rank <= 1, do nothing
+ if (target.getNumLoops() <= 1) {
+ results.push_back(target);
return DiagnosedSilenceableFailure::success();
+ }
ReassociationIndices reassociation(target.getNumLoops());
std::iota(reassociation.begin(), reassociation.end(), 0);
auto maybeFlattened =
- (isElementwise(target))
- ? collapseOpIterationDims(target, reassociation, rewriter)
- : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
- target, "only elementwise flattening is supported"));
+ collapseOpIterationDims(target, reassociation, rewriter);
if (failed(maybeFlattened))
return emitDefaultSilenceableFailure(target);
results.push_back(maybeFlattened->collapsedOp);
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 858c133dd536ca..5a27fe76b13411 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @map_already_flat(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-NEXT: linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
+func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+ linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @generic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/86037
More information about the Mlir-commits mailing list