[Mlir-commits] [mlir] [mlir][transform] Fix failure in flattening already flattened linalg ops (PR #86037)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 20 19:53:06 PDT 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/86037

>From 127b99832b0051954ee0cd0247a0735e74126df7 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 20 Mar 2024 18:30:30 -0500
Subject: [PATCH] Fix failure in flattening already flattened linalg ops

---
 .../TransformOps/LinalgTransformOps.cpp       | 15 ++++++++-----
 .../Dialect/Linalg/flatten-elementwise.mlir   | 21 +++++++++++++++++++
 2 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..c93b656f42353c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  if (target.getNumLoops() <= 1)
+  if (!isElementwise(target)) {
+    failed(rewriter.notifyMatchFailure(
+        target, "only elementwise flattening is supported"));
+    return emitDefaultSilenceableFailure(target);
+  }
+  // With at most one loop there is nothing to flatten; return the op as is.
+  if (target.getNumLoops() <= 1) {
+    results.push_back(target);
     return DiagnosedSilenceableFailure::success();
+  }
   ReassociationIndices reassociation(target.getNumLoops());
   std::iota(reassociation.begin(), reassociation.end(), 0);
   auto maybeFlattened =
-      (isElementwise(target))
-          ? collapseOpIterationDims(target, reassociation, rewriter)
-          : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
-                target, "only elementwise flattening is supported"));
+      collapseOpIterationDims(target, reassociation, rewriter);
   if (failed(maybeFlattened))
     return emitDefaultSilenceableFailure(target);
   results.push_back(maybeFlattened->collapsedOp);
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 858c133dd536ca..5a27fe76b13411 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func.func @map_already_flat(
+// CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME:                 %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME:                 %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-NEXT:    linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
+func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+    linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @generic
 // CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>



More information about the Mlir-commits mailing list