[Mlir-commits] [mlir] [mlir] Allow multi-result ops in reshape fusion (PR #108576)

llvmlistbot at llvm.org
Fri Sep 13 07:47:50 PDT 2024


https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/108576

The fusion-of-reshapes-by-collapsing patterns were restricted to single-result operations, but the underlying implementation already supports multi-result ops. This PR removes the restriction, since it is not necessary.

From 848a59d97d5dd6113bee09df431496971bcdeb1d Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 13 Sep 2024 09:56:50 -0400
Subject: [PATCH] [mlir] Allow multi-result ops in reshape fusion

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |  2 +-
 .../fuse-with-reshape-by-collapsing.mlir      | 42 +++++++++++--------
 2 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index c818675993c2c3..a934e47794051c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1254,7 +1254,7 @@ static SmallVector<ReassociationIndices>
 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                  ArrayRef<ReassociationIndices> reassociation) {
   // Some basic checks for this fusion to be valid.
-  if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
+  if (!genericOp.hasPureTensorSemantics())
     return {};
 
   if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 600f0dea31f4a8..f17881d59a266e 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -7,49 +7,55 @@
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
 #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
 func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>,
-    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
   %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
-  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
-  %generic = linalg.generic {
-    indexing_maps = [#map0, #map1, #map2, #map3],
+  %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+  %generic:2 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2, #map3, #map4],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
     ins(%expand, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
-    outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
-      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+    outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
         %t0 = arith.addi %b0, %b1 : i32
         %t1 = arith.addi %t0, %b2 : i32
-        linalg.yield %t1 : i32
-    } -> tensor<2x3x4x5x6x7x8x9xi32>
-  return %generic : tensor<2x3x4x5x6x7x8x9xi32>
+        linalg.yield %t1, %t1 : i32, i32
+    } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+  return %generic#0, %generic#1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
 //      CHECK: func @fuse_by_collapsing(
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
 // CHECK-SAME:   %[[ARG1:.+]]: tensor<2x3x4xi32>
 // CHECK-SAME:   %[[ARG2:.+]]: tensor<5x6x7x8xi32>
-//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
+//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+//  CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
 //  CHECK-DAG:   %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
 //  CHECK-DAG:   %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
-//  CHECK-DAG:   %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
-//      CHECK:   %[[COLLAPSED_OP:.+]] = linalg.generic
-// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+//  CHECK-DAG:   %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+//  CHECK-DAG:   %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+//      CHECK:   %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
-// CHECK-SAME:       outs(%[[INIT_RESHAPE]] :
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
-//      CHECK:   return %[[RESULT_RESHAPE]]
+// CHECK-SAME:       outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+//      CHECK:   %[[RESULT0_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#0 {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
+//      CHECK:   %[[RESULT1_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#1 {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}} output_shape [3, 4, 2, 9, 5, 6, 7, 8]
+//      CHECK:   return %[[RESULT0_RESHAPE]], %[[RESULT1_RESHAPE]]
 
 //      CONTROL: func @fuse_by_collapsing(
 // CONTROL-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
 // CONTROL-SAME:   %[[ARG1:.+]]: tensor<2x3x4xi32>
 // CONTROL-SAME:   %[[ARG2:.+]]: tensor<5x6x7x8xi32>
 //      CONTROL:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
-//      CONTROL:   %[[GENERIC:.+]] = linalg.generic
+//      CONTROL:   %[[GENERIC:.+]]:2 = linalg.generic
 // CONTROL-SAME:       ins(%[[EXPAND]],
-//      CONTROL:   return %[[GENERIC]]
+//      CONTROL:   return %[[GENERIC]]#0, %[[GENERIC]]#1
 
 // -----
 


