[Mlir-commits] [mlir] df9ed9c - [mlir][transform] Fix failure in flattening already flattened linalg ops (#86037)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 20 22:25:10 PDT 2024


Author: srcarroll
Date: 2024-03-21T00:25:07-05:00
New Revision: df9ed9cf52f82aed023adc968ca2a0e7f7cccc69

URL: https://github.com/llvm/llvm-project/commit/df9ed9cf52f82aed023adc968ca2a0e7f7cccc69
DIFF: https://github.com/llvm/llvm-project/commit/df9ed9cf52f82aed023adc968ca2a0e7f7cccc69.diff

LOG: [mlir][transform] Fix failure in flattening already flattened linalg ops (#86037)

The previous implementation was doing an early successful return on
`rank <= 1` without adding the original op to the transform results. This
resulted in errors about the number of returns. This patch fixes this by
adding the original op to the results. Additionally, we first check whether
the op is elementwise and return a silenceable failure early if not.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/flatten-elementwise.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d82a6beb1086e5..ecf9983124821a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3269,15 +3269,20 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  if (target.getNumLoops() <= 1)
+  if (!isElementwise(target)) {
+    failed(rewriter.notifyMatchFailure(
+        target, "only elementwise flattening is supported"));
+    return emitDefaultSilenceableFailure(target);
+  }
+  // If rank <= 1, do nothing
+  if (target.getNumLoops() <= 1) {
+    results.push_back(target);
     return DiagnosedSilenceableFailure::success();
+  }
   ReassociationIndices reassociation(target.getNumLoops());
   std::iota(reassociation.begin(), reassociation.end(), 0);
   auto maybeFlattened =
-      (isElementwise(target))
-          ? collapseOpIterationDims(target, reassociation, rewriter)
-          : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
-                target, "only elementwise flattening is supported"));
+      collapseOpIterationDims(target, reassociation, rewriter);
   if (failed(maybeFlattened))
     return emitDefaultSilenceableFailure(target);
   results.push_back(maybeFlattened->collapsedOp);

diff  --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 858c133dd536ca..5a27fe76b13411 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -67,6 +67,27 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func.func @map_already_flat(
+// CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME:                 %[[ARG1:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-SAME:                 %[[ARG2:[a-zA-Z0-9_]*]]: memref<32xf32>
+// CHECK-NEXT:    linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG1]] : memref<32xf32>, memref<32xf32>) outs(%[[ARG2]] : memref<32xf32>)
+func.func @map_already_flat(%arg0: memref<32xf32>, %arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+    linalg.map {arith.addf} ins(%arg0, %arg1: memref<32xf32>, memref<32xf32>) outs(%arg2: memref<32xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @generic
 // CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>


        


More information about the Mlir-commits mailing list