[Mlir-commits] [mlir] [mlir][linalg] Add transpose support for reshape as consumer fusion (PR #130344)
Nirvedh Meshram
llvmlistbot at llvm.org
Mon Mar 10 18:04:46 PDT 2025
https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/130344
>From 14cb2327a55916d035de9cf198338237b0cf80bb Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Thu, 6 Mar 2025 21:52:59 -0600
Subject: [PATCH 1/2] [mlir][linalg] Add transpose support for reshape as
consumer fusion
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 50 +++++++++++--
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 75 +++++++++++--------
2 files changed, 88 insertions(+), 37 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index a45b5c43f5d33..222fdd40ca12f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -816,17 +816,51 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
}
// Create an expanded transpose op.
+// For bubbling a collapse : transpose(collapse_shape),
+// all expanded groups are permuted together. We just permute the reassocation
+// map of the collapse and flatten it. For example,
+//
+// reassociation_map = [[0], [1, 2, 3], [4, 5]]
+// permutation = [2, 0, 1]
+//
+// Becomes
+//
+// permutation = [4, 5, 0 , 1, 2, 3]
+//
+// For sinking expand : expand_shape(transpose),
+// the reassociation map is already permuted hence we inverse permute and then
+// flatten it. Then we inverse permute it again to get the final expanded
+// transpose permutation. For example,
+//
+// permutation = [2, 0, 1]
+// reassociation_map = [[0, 1], [2], [3, 4, 5]]
+//
+// inverse permutation = [1, 2, 0]
+// applied to reassociation_map and then flattened becomes
+// flattened permutation = [2, 3, 4, 5, 0, 1]
+// final permutation is the inverse of the flattened permutation.
+//
+// Becomes
+//
+// permutation=[4, 5, 0, 1, 2, 3]
+
static Operation *
createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
SmallVector<ReassociationIndices> reassociation,
- Value expandedInput, Value output) {
- applyPermutationToVector(reassociation, transposeOp.getPermutation());
+ Value expandedInput, Value output, bool isExpanding) {
+ ArrayRef<int64_t> permutation =
+ isExpanding ? invertPermutationVector(transposeOp.getPermutation())
+ : transposeOp.getPermutation();
+ applyPermutationToVector(reassociation, permutation);
SmallVector<int64_t> newPerm;
for (auto reassoc : reassociation) {
for (auto dim : reassoc) {
newPerm.push_back(dim);
}
}
+ if (isExpanding) {
+ newPerm = invertPermutationVector(newPerm);
+ }
return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
output, newPerm);
}
@@ -866,12 +900,13 @@ static Operation *createExpandedOp(
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
- SmallVector<ReassociationIndices> reassociation) {
+ SmallVector<ReassociationIndices> reassociation, bool isExpanding) {
return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
.Case<TransposeOp>([&](TransposeOp transposeOp) {
return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
- expandedOpOperands[0], outputs[0]);
+ expandedOpOperands[0], outputs[0],
+ isExpanding);
})
.Case<FillOp, CopyOp>([&](Operation *op) {
return clone(rewriter, linalgOp, resultTypes,
@@ -994,9 +1029,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
SmallVector<ReassociationIndices> reassociationBeforeExpansion =
isExpanding ? expandingReshapeOp.getReassociationIndices()
: collapsingReshapeOp.getReassociationIndices();
- Operation *fusedOp = createExpandedOp(
- rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
- expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
+ Operation *fusedOp =
+ createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
+ outputs, expandedOpIndexingMaps, expansionInfo,
+ reassociationBeforeExpansion, isExpanding);
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 4da9c0851ac70..7c2b55ca745ff 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -195,7 +195,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-SAME: : tensor<8x33x4xf32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
@@ -203,6 +203,29 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
// CHECK: return %[[T2]] : tensor<8x33x4xf32>
+// -----
+
+func.func @reshape_as_consumer_transpose
+ (%a : tensor<4x210x6xf32>)
+ -> tensor<2x3x4x5x6x7xf32> {
+ %b = tensor.empty() : tensor<6x4x210xf32>
+ %c = linalg.transpose
+ ins(%a : tensor<4x210x6xf32>)
+ outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+ %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+ return %d : tensor<2x3x4x5x6x7xf32>
+}
+// CHECK: func @reshape_as_consumer_transpose
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32>
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32>
+// CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32
+// CHECK: %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
+// CHECK: return %[[T2]] : tensor<2x3x4x5x6x7xf32>
+
+
// -----
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
@@ -859,37 +882,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// -----
-func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
- %arg1 : tensor<?x?xf32>) ->
- tensor<?x?xf32>
-{
- %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
- tensor<?x7x?x8xf32> into tensor<?x?xf32>
- %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
- outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
- return %1 : tensor<?x?xf32>
+
+func.func @reshape_as_producer_transpose
+ (%a : tensor<4x5x6x7x2x3xf32>)
+ -> tensor<6x4x210xf32> {
+ %b = tensor.empty() : tensor<6x4x210xf32>
+ %c = tensor.collapse_shape %a [[0], [1, 2, 3], [4, 5]] :
+ tensor<4x5x6x7x2x3xf32> into tensor<4x210x6xf32>
+ %d = linalg.transpose
+ ins(%c : tensor<4x210x6xf32>)
+ outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+ return %d : tensor<6x4x210xf32>
}
-// CHECK: func @linalg_transpose_reshape_producer_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
-// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
-// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
-// CHECK: %[[T2:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
-// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
-// CHECK-SAME: permutation = [2, 3, 0, 1]
-// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
-// CHECK-SAME: [0, 1], [2, 3]
-// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
-// CHECK: return %[[T3]]
+// CHECK: func @reshape_as_producer_transpose
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32>
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+// CHECK: %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
+// CHECK: %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32>
+// CHECK: return %[[T2]] : tensor<6x4x210xf32>
+
// -----
>From 092eb1860ad35dbc8bcf469150a9141daad7be35 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 10 Mar 2025 20:02:39 -0500
Subject: [PATCH 2/2] fix sinking expand naming in comment
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 222fdd40ca12f..1a10aa9a41b06 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -816,7 +816,7 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
}
// Create an expanded transpose op.
-// For bubbling a collapse : transpose(collapse_shape),
+// For sinking a collapse : transpose(collapse_shape),
// all expanded groups are permuted together. We just permute the reassocation
// map of the collapse and flatten it. For example,
//
@@ -827,7 +827,7 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
//
// permutation = [4, 5, 0 , 1, 2, 3]
//
-// For sinking expand : expand_shape(transpose),
+// For bubbling an expand : expand_shape(transpose),
// the reassociation map is already permuted hence we inverse permute and then
// flatten it. Then we inverse permute it again to get the final expanded
// transpose permutation. For example,
More information about the Mlir-commits
mailing list