[Mlir-commits] [mlir] [mlir][linalg] Allow fusing reshapes with nonparallel operands (PR #130148)

Ian Wood llvmlistbot at llvm.org
Thu Mar 6 18:16:37 PST 2025


https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/130148

>From f31cf5aae3fb03d2afbdbef5456717e83e8b209e Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Thu, 6 Mar 2025 21:10:21 -0800
Subject: [PATCH 1/2] [mlir][linalg] Allow fusing reshapes with nonparallel
 operands

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index a45b5c43f5d33..337fd8f3a0ac1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -566,7 +566,6 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops for the reshaped operand are parallel loops.
   SmallVector<utils::IteratorType> iteratorTypes =
       linalgOp.getIteratorTypesArray();
   AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
@@ -577,11 +576,7 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         operandMap.getNumResults() > 0 &&
-         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
-           return isParallelIterator(
-               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
-         });
+         operandMap.getNumResults() > 0;
 }
 
 namespace {

>From 0201daa3e0773993c3e5df21209971f2dc95b28e Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Fri, 7 Mar 2025 06:12:25 -0800
Subject: [PATCH 2/2] Add test to check reduction reshape fusion

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/test/Dialect/Linalg/reshape_fusion.mlir | 25 ++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 4da9c0851ac70..c8720ebd98c09 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -482,6 +482,31 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 
 // -----
 
+func.func @fuse_collapse_reduction(%arg0: tensor<10x10x20xf32>) -> tensor<100xf32> {
+  %c0 = arith.constant 0 : index
+  %c_0 = arith.constant 0.0 : f32
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<10x10x20xf32> into tensor<100x20xf32>
+  %2 = tensor.empty() : tensor<100xf32>
+  %3 = linalg.fill ins(%c_0 : f32) outs(%2 : tensor<100xf32>) -> tensor<100xf32>
+  %4 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%0 : tensor<100x20xf32>) outs(%3 : tensor<100xf32>) {
+      ^bb0(%arg1 : f32, %arg2: f32):
+        %4 = arith.addf %arg1, %arg2 : f32
+        linalg.yield %4 : f32
+    } -> tensor<100xf32>
+  return %4 : tensor<100xf32>
+}
+
+//      CHECK: func @fuse_collapse_reduction
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<10x10x20xf32>
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[ARG0]] : tensor<10x10x20xf32>)
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+//      CHECK:   return %[[COLLAPSE]]
+// -----
+
 func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>



More information about the Mlir-commits mailing list