[Mlir-commits] [mlir] [mlir][transform] Fix failure in flattening already flattened linalg ops (PR #86037)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 20 16:32:00 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: None (srcarroll)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/86037.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+10-5) 
- (modified) mlir/test/Dialect/Linalg/flatten-elementwise.mlir (+21) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..c93b656f42353c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  if (target.getNumLoops() <= 1)
+  if (!isElementwise(target)) {
+    failed(rewriter.notifyMatchFailure(
+        target, "only elementwise flattening is supported"));
+    return emitDefaultSilenceableFailure(target);
+  }
+  // If rank <= 1, do nothing
+  if (target.getNumLoops() <= 1) {
+    results.push_back(target);
     return DiagnosedSilenceableFailure::success();
+  }
   ReassociationIndices reassociation(target.getNumLoops());
   std::iota(reassociation.begin(), reassociation.end(), 0);
   auto maybeFlattened =
-      (isElementwise(target))
-          ? collapseOpIterationDims(target, reassociation, rewriter)
-          : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
-                target, "only elementwise flattening is supported"));
+      collapseOpIterationDims(target, reassociation, rewriter);
   if (failed(maybeFlattened))
     return emitDefaultSilenceableFailure(target);
   results.push_back(maybeFlattened->collapsedOp);
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 858c133dd536ca..5a27fe76b13411 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func.func @map_already_flat(
+// CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME:                 %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME:                 %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-NEXT:    linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
+func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+    linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @generic
 // CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/86037


More information about the Mlir-commits mailing list