[Mlir-commits] [mlir] [mlir][linalg] Allow fusing reshapes with nonparallel operands (PR #130148)
Ian Wood
llvmlistbot at llvm.org
Thu Mar 6 18:16:37 PST 2025
https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/130148
From f31cf5aae3fb03d2afbdbef5456717e83e8b209e Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Thu, 6 Mar 2025 21:10:21 -0800
Subject: [PATCH 1/2] [mlir][linalg] Allow fusing reshapes with non-parallel
 operands
Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 7 +------
1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index a45b5c43f5d33..337fd8f3a0ac1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -566,7 +566,6 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
// - All the indexing maps for operands and results are projected
// permutations.
// - The fused tensor is not a scalar.
- // - All the loops for the reshaped operand are parallel loops.
SmallVector<utils::IteratorType> iteratorTypes =
linalgOp.getIteratorTypesArray();
AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
@@ -577,11 +576,7 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
.getValue()
.isProjectedPermutation();
}) &&
- operandMap.getNumResults() > 0 &&
- llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
- return isParallelIterator(
- iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
- });
+ operandMap.getNumResults() > 0;
}
namespace {
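With the parallel-only clause gone, isFusableWithReshapeByDimExpansion only requires that every indexing map be a projected permutation and that the reshaped operand not be rank-0. As a minimal sketch (not part of the patch; the function and value names below are made up for illustration), an op like the following, whose collapsed operand is indexed by the reduction loop d1, was rejected by the old check but now passes this precondition:

// Illustrative only, not from the patch: the collapse_shape result feeds an
// operand whose indexing map references the reduction iterator d1, which the
// removed clause used to disallow.
func.func @collapse_into_reduction_operand(%arg0: tensor<4x5x6xf32>,
                                           %init: tensor<6xf32>) -> tensor<6xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
      : tensor<4x5x6xf32> into tensor<20x6xf32>
  %1 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%0 : tensor<20x6xf32>) outs(%init : tensor<6xf32>) {
    ^bb0(%in: f32, %out: f32):
      %sum = arith.addf %in, %out : f32
      linalg.yield %sum : f32
  } -> tensor<6xf32>
  return %1 : tensor<6xf32>
}

The test added in the second patch below exercises the same relaxation in the case where the reshape expands a parallel loop while the operand map still references a reduction loop.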
>From 0201daa3e0773993c3e5df21209971f2dc95b28e Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Fri, 7 Mar 2025 06:12:25 -0800
Subject: [PATCH 2/2] Add test to check reduction reshape fusion
Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 25 ++++++++++++++++++++
1 file changed, 25 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 4da9c0851ac70..c8720ebd98c09 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -482,6 +482,31 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// -----
+func.func @fuse_collapse_reduction(%arg0: tensor<10x10x20xf32>) -> tensor<100xf32> {
+ %c0 = arith.constant 0 : index
+ %c_0 = arith.constant 0.0 : f32
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<10x10x20xf32> into tensor<100x20xf32>
+ %2 = tensor.empty() : tensor<100xf32>
+ %3 = linalg.fill ins(%c_0 : f32) outs(%2 : tensor<100xf32>) -> tensor<100xf32>
+ %4 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%0 : tensor<100x20xf32>) outs(%3 : tensor<100xf32>) {
+ ^bb0(%arg1 : f32, %arg2: f32):
+ %4 = arith.addf %arg1, %arg2 : f32
+ linalg.yield %4 : f32
+ } -> tensor<100xf32>
+ return %4 : tensor<100xf32>
+}
+
+// CHECK: func @fuse_collapse_reduction
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10x10x20xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<10x10x20xf32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+// CHECK: return %[[COLLAPSE]]
+// -----
+
func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
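For reference, the fused form that the CHECK lines above describe looks roughly like the sketch below. This is a sketch only: the SSA names are illustrative, and the pattern may reshape the original tensor<100xf32> fill (e.g. via tensor.expand_shape) rather than recreate it at the expanded type as written here; the CHECK lines only pin down that the generic consumes the unreshaped 3-D operand and that the collapse_shape moves after it.

// Rough sketch of the expected fused IR; names and init handling are illustrative.
func.func @fuse_collapse_reduction_fused_sketch(%arg0: tensor<10x10x20xf32>) -> tensor<100xf32> {
  %cst = arith.constant 0.0 : f32
  // Init shown directly at the expanded 2-D shape for readability; the actual
  // pattern output may reshape the original 1-D fill instead.
  %empty = tensor.empty() : tensor<10x10xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x10xf32>) -> tensor<10x10xf32>
  // The reduction now runs directly over the 3-D operand: the collapsed
  // parallel loop is split into d0 and d1, and d2 stays the reduction.
  %reduced = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%arg0 : tensor<10x10x20xf32>) outs(%fill : tensor<10x10xf32>) {
    ^bb0(%in: f32, %out: f32):
      %sum = arith.addf %in, %out : f32
      linalg.yield %sum : f32
  } -> tensor<10x10xf32>
  // The collapse moves after the generic, folding the two leading parallel
  // dims back into the original 100-element dim.
  %collapsed = tensor.collapse_shape %reduced [[0, 1]] : tensor<10x10xf32> into tensor<100xf32>
  return %collapsed : tensor<100xf32>
}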