[Mlir-commits] [mlir] 849abd8 - [mlir][linalg] Add transpose support for reshape as consumer fusion (#130344)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 11 05:48:42 PDT 2025
Author: Nirvedh Meshram
Date: 2025-03-11T07:48:38-05:00
New Revision: 849abd8c05cf9899cd943a7b56bae57f93ea80cb
URL: https://github.com/llvm/llvm-project/commit/849abd8c05cf9899cd943a7b56bae57f93ea80cb
DIFF: https://github.com/llvm/llvm-project/commit/849abd8c05cf9899cd943a7b56bae57f93ea80cb.diff
LOG: [mlir][linalg] Add transpose support for reshape as consumer fusion (#130344)
During https://github.com/llvm/llvm-project/pull/129128, which added reshape
as consumer fusion, handling of linalg.transpose was missed. This PR adds
that.
Also, the transpose reshape as producer fusion test is updated to use static
sizes, as that is more likely to catch any issues with the permutation
vector in the verifier if the shapes don't match up.
---------
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/reshape_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3f016fed3519c..33667e7ab0c5c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -811,19 +811,35 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
}
// Create an expanded transpose op.
-static Operation *
-createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
- SmallVector<ReassociationIndices> reassociation,
- Value expandedInput, Value output) {
- applyPermutationToVector(reassociation, transposeOp.getPermutation());
+// The reassociation map is already permuted, hence we inverse-permute it and
+// then flatten it. Then we invert the flattened permutation to get the final
+// expanded transpose permutation. For example,
+//
+// permutation = [2, 0, 1]
+// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
+//
+// inverse permutation = [1, 2, 0]
+// applied to reassociation_map and then flattened becomes
+// flattened permutation = [2, 3, 4, 5, 0, 1]
+// final permutation is the inverse of the flattened permutation.
+//
+// Becomes
+//
+// permutation=[4, 5, 0, 1, 2, 3]
+
+static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
+ TransposeOp transposeOp,
+ Value expandedInput, Value output,
+ ExpansionInfo &expansionInfo) {
SmallVector<int64_t> newPerm;
- for (const auto &reassoc : reassociation) {
- for (auto dim : reassoc) {
+ for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
+ auto reassoc = expansionInfo.getExpandedDims(perm);
+ for (int64_t dim : reassoc) {
newPerm.push_back(dim);
}
}
return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
- output, newPerm);
+ output, invertPermutationVector(newPerm));
}
// Create an expanded generic op.
@@ -857,16 +873,18 @@ static Operation *createExpandedGenericOp(
// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose and produce a generic op for
// rest of linalg ops.
-static Operation *createExpandedOp(
- PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
- ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
- ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
- SmallVector<ReassociationIndices> reassociation) {
+static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
+ TypeRange resultTypes,
+ ArrayRef<Value> expandedOpOperands,
+ ArrayRef<Value> outputs,
+ ArrayRef<AffineMap> expandedOpIndexingMaps,
+ ExpansionInfo &expansionInfo) {
return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
.Case<TransposeOp>([&](TransposeOp transposeOp) {
- return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
- expandedOpOperands[0], outputs[0]);
+ return createExpandedTransposeOp(rewriter, transposeOp,
+ expandedOpOperands[0], outputs[0],
+ expansionInfo);
})
.Case<FillOp, CopyOp>([&](Operation *op) {
return clone(rewriter, linalgOp, resultTypes,
@@ -986,12 +1004,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
}
TypeRange resultTypes = ValueRange(outputs).getTypes();
- SmallVector<ReassociationIndices> reassociationBeforeExpansion =
- isExpanding ? expandingReshapeOp.getReassociationIndices()
- : collapsingReshapeOp.getReassociationIndices();
- Operation *fusedOp = createExpandedOp(
- rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
- expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
+ Operation *fusedOp =
+ createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
+ outputs, expandedOpIndexingMaps, expansionInfo);
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index c8720ebd98c09..3244418d445b7 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -195,7 +195,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-SAME: : tensor<8x33x4xf32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
@@ -203,6 +203,29 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
// CHECK: return %[[T2]] : tensor<8x33x4xf32>
+// -----
+
+func.func @reshape_as_consumer_transpose
+ (%a : tensor<4x210x6xf32>)
+ -> tensor<2x3x4x5x6x7xf32> {
+ %b = tensor.empty() : tensor<6x4x210xf32>
+ %c = linalg.transpose
+ ins(%a : tensor<4x210x6xf32>)
+ outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+ %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+ return %d : tensor<2x3x4x5x6x7xf32>
+}
+// CHECK: func @reshape_as_consumer_transpose
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32>
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32>
+// CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32
+// CHECK: %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
+// CHECK: return %[[T2]] : tensor<2x3x4x5x6x7xf32>
+
+
// -----
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
@@ -884,37 +907,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// -----
-func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
- %arg1 : tensor<?x?xf32>) ->
- tensor<?x?xf32>
-{
- %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
- tensor<?x7x?x8xf32> into tensor<?x?xf32>
- %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
- outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
- return %1 : tensor<?x?xf32>
+
+func.func @reshape_as_producer_transpose
+ (%a : tensor<4x5x6x7x2x3xf32>)
+ -> tensor<6x4x210xf32> {
+ %b = tensor.empty() : tensor<6x4x210xf32>
+ %c = tensor.collapse_shape %a [[0], [1, 2, 3], [4, 5]] :
+ tensor<4x5x6x7x2x3xf32> into tensor<4x210x6xf32>
+ %d = linalg.transpose
+ ins(%c : tensor<4x210x6xf32>)
+ outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+ return %d : tensor<6x4x210xf32>
}
-// CHECK: func @linalg_transpose_reshape_producer_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
-// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
-// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
-// CHECK: %[[T2:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
-// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
-// CHECK-SAME: permutation = [2, 3, 0, 1]
-// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
-// CHECK-SAME: [0, 1], [2, 3]
-// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
-// CHECK: return %[[T3]]
+// CHECK: func @reshape_as_producer_transpose
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32>
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+// CHECK: %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
+// CHECK: %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32>
+// CHECK: return %[[T2]] : tensor<6x4x210xf32>
+
// -----
More information about the Mlir-commits
mailing list