[Mlir-commits] [mlir] [mlir] Allow multi-result ops in reshape fusion (PR #108576)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 13 07:47:50 PDT 2024
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/108576
Fusion of reshapes by collapsing patterns was restricted to single-result operations, but the implementation already supports multi-result ops. This PR removes the restriction, since it is not necessary.
From 848a59d97d5dd6113bee09df431496971bcdeb1d Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 13 Sep 2024 09:56:50 -0400
Subject: [PATCH] [mlir] Allow multi-result ops in reshape fusion
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 2 +-
.../fuse-with-reshape-by-collapsing.mlir | 42 +++++++++++--------
2 files changed, 25 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index c818675993c2c3..a934e47794051c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1254,7 +1254,7 @@ static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
ArrayRef<ReassociationIndices> reassociation) {
// Some basic checks for this fusion to be valid.
- if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
+ if (!genericOp.hasPureTensorSemantics())
return {};
if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 600f0dea31f4a8..f17881d59a266e 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -7,49 +7,55 @@
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>,
- %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
+ %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
%expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
- %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
- %generic = linalg.generic {
- indexing_maps = [#map0, #map1, #map2, #map3],
+ %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+ %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+ %generic:2 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map2, #map3, #map4],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
ins(%expand, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
- outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
- ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+ outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+ ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
%t0 = arith.addi %b0, %b1 : i32
%t1 = arith.addi %t0, %b2 : i32
- linalg.yield %t1 : i32
- } -> tensor<2x3x4x5x6x7x8x9xi32>
- return %generic : tensor<2x3x4x5x6x7x8x9xi32>
+ linalg.yield %t1, %t1 : i32, i32
+ } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+ return %generic#0, %generic#1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
// CHECK: func @fuse_by_collapsing(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32>
-// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
// CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
// CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
-// CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
-// CHECK: %[[COLLAPSED_OP:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+// CHECK-DAG: %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+// CHECK: %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
-// CHECK-SAME: outs(%[[INIT_RESHAPE]] :
-// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
-// CHECK: return %[[RESULT_RESHAPE]]
+// CHECK-SAME: outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+// CHECK: %[[RESULT0_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#0 {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
+// CHECK: %[[RESULT1_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]]#1 {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}} output_shape [3, 4, 2, 9, 5, 6, 7, 8]
+// CHECK: return %[[RESULT0_RESHAPE]], %[[RESULT1_RESHAPE]]
// CONTROL: func @fuse_by_collapsing(
// CONTROL-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
// CONTROL-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32>
// CONTROL-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32>
// CONTROL: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
-// CONTROL: %[[GENERIC:.+]] = linalg.generic
+// CONTROL: %[[GENERIC:.+]]:2 = linalg.generic
// CONTROL-SAME: ins(%[[EXPAND]],
-// CONTROL: return %[[GENERIC]]
+// CONTROL: return %[[GENERIC]]#0, %[[GENERIC]]#1
// -----
More information about the Mlir-commits
mailing list