[Mlir-commits] [mlir] [mlir][linalg] Add transpose support for reshape as consumer fusion (PR #130344)

Nirvedh Meshram llvmlistbot at llvm.org
Mon Mar 10 18:04:46 PDT 2025


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/130344

>From 14cb2327a55916d035de9cf198338237b0cf80bb Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Thu, 6 Mar 2025 21:52:59 -0600
Subject: [PATCH 1/2] [mlir][linalg] Add transpose support for reshape as
 consumer fusion

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 50 +++++++++++--
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 75 +++++++++++--------
 2 files changed, 88 insertions(+), 37 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index a45b5c43f5d33..222fdd40ca12f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -816,17 +816,51 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
 }
 
 // Create an expanded transpose op.
+// For bubbling a collapse : transpose(collapse_shape),
+// all expanded groups are permuted together. We just permute the reassocation
+// map of the collapse and flatten it. For example,
+//
+// reassociation_map = [[0], [1, 2, 3], [4, 5]]
+// permutation = [2, 0, 1]
+//
+// Becomes
+//
+// permutation = [4, 5, 0 , 1, 2, 3]
+//
+// For sinking expand : expand_shape(transpose),
+// the reassociation map is already permuted hence we inverse permute and then
+// flatten it. Then we inverse permute it again to get the final expanded
+// transpose permutation. For example,
+//
+// permutation = [2, 0, 1]
+// reassociation_map = [[0, 1], [2], [3, 4, 5]]
+//
+// inverse permutation = [1, 2, 0]
+// applied to reassociation_map and then flattened becomes
+// flattened permutation = [2, 3, 4, 5, 0, 1]
+// final permutation is the inverse of the flattened permutation.
+//
+// Becomes
+//
+// permutation = [4, 5, 0, 1, 2, 3]
+
 static Operation *
 createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
                           SmallVector<ReassociationIndices> reassociation,
-                          Value expandedInput, Value output) {
-  applyPermutationToVector(reassociation, transposeOp.getPermutation());
+                          Value expandedInput, Value output, bool isExpanding) {
+  SmallVector<int64_t> permutation(transposeOp.getPermutation());
+  if (isExpanding)
+    permutation = invertPermutationVector(permutation);
+  applyPermutationToVector(reassociation, permutation);
   SmallVector<int64_t> newPerm;
   for (auto reassoc : reassociation) {
     for (auto dim : reassoc) {
       newPerm.push_back(dim);
     }
   }
+  if (isExpanding) {
+    newPerm = invertPermutationVector(newPerm);
+  }
   return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
                                       output, newPerm);
 }
@@ -866,12 +900,13 @@ static Operation *createExpandedOp(
     PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
     ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
     ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
-    SmallVector<ReassociationIndices> reassociation) {
+    SmallVector<ReassociationIndices> reassociation, bool isExpanding) {
 
   return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
       .Case<TransposeOp>([&](TransposeOp transposeOp) {
         return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
-                                         expandedOpOperands[0], outputs[0]);
+                                         expandedOpOperands[0], outputs[0],
+                                         isExpanding);
       })
       .Case<FillOp, CopyOp>([&](Operation *op) {
         return clone(rewriter, linalgOp, resultTypes,
@@ -994,9 +1029,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   SmallVector<ReassociationIndices> reassociationBeforeExpansion =
       isExpanding ? expandingReshapeOp.getReassociationIndices()
                   : collapsingReshapeOp.getReassociationIndices();
-  Operation *fusedOp = createExpandedOp(
-      rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
-      expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
+  Operation *fusedOp =
+      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
+                       outputs, expandedOpIndexingMaps, expansionInfo,
+                       reassociationBeforeExpansion, isExpanding);
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 4da9c0851ac70..7c2b55ca745ff 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -195,7 +195,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 // CHECK-SAME:     : tensor<8x33x4xf32>
 //  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
 //      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
 //      CHECK:   %[[T2:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel"]
@@ -203,6 +203,29 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 // CHECK-SAME:     outs(%[[T1]] : tensor<8x33x4xf32>)
 //      CHECK:   return %[[T2]] : tensor<8x33x4xf32>
 
+// -----
+
+func.func @reshape_as_consumer_transpose
+  (%a :  tensor<4x210x6xf32>)
+    -> tensor<2x3x4x5x6x7xf32> {
+  %b = tensor.empty() : tensor<6x4x210xf32>
+  %c = linalg.transpose
+          ins(%a : tensor<4x210x6xf32>)
+         outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+  return %d : tensor<2x3x4x5x6x7xf32>
+}
+//      CHECK: func @reshape_as_consumer_transpose
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32>
+//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
+//  CHECK-DAG:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32>
+//  CHECK-DAG:   %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+//      CHECK:   %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME:     outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME:     permutation = [4, 5, 0, 1, 2, 3]
+//      CHECK:   return %[[T2]] : tensor<2x3x4x5x6x7xf32>
+
+
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
@@ -859,37 +882,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 
 // -----
 
-func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
-                                              %arg1 : tensor<?x?xf32>) ->
-                                              tensor<?x?xf32>
-{
-  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
-    tensor<?x7x?x8xf32> into tensor<?x?xf32>
-  %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
-       outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
-  return %1 : tensor<?x?xf32>
+
+func.func @reshape_as_producer_transpose
+  (%src :  tensor<4x5x6x7x2x3xf32>)
+    -> tensor<6x4x210xf32> {
+  %empty = tensor.empty() : tensor<6x4x210xf32>
+  %collapsed = tensor.collapse_shape %src [[0], [1, 2, 3], [4, 5]] :
+    tensor<4x5x6x7x2x3xf32> into tensor<4x210x6xf32>
+  %transposed = linalg.transpose
+          ins(%collapsed : tensor<4x210x6xf32>)
+         outs(%empty : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+  return %transposed : tensor<6x4x210xf32>
 }
 
-//      CHECK: func @linalg_transpose_reshape_producer_fusion
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//  CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
-//  CHECK-DAG:   %[[C7:.+]] = arith.constant 7 : index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//  CHECK-DAG:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//  CHECK-DAG:   %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
-//  CHECK-DAG:   %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
-//      CHECK:   %[[T2:.+]] = linalg.transpose
-// CHECK-SAME:     ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
-// CHECK-SAME:     outs(%[[T1]] : tensor<?x8x?x7xf32>)
-// CHECK-SAME:   permutation = [2, 3, 0, 1]
-//      CHECK:   %[[T3:.+]] = tensor.collapse_shape %[[T2]]
-// CHECK-SAME:     [0, 1], [2, 3]
-// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
-//      CHECK:   return %[[T3]]
+//      CHECK: func @reshape_as_producer_transpose
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32>
+//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
+//  CHECK-DAG:   %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+//      CHECK:   %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME:     outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME:     permutation = [4, 5, 0, 1, 2, 3]
+//      CHECK:   %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32>
+//      CHECK:   return %[[T2]] : tensor<6x4x210xf32>
+
 
 // -----
 

>From 092eb1860ad35dbc8bcf469150a9141daad7be35 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 10 Mar 2025 20:02:39 -0500
Subject: [PATCH 2/2] fix sinking expand naming in comment

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 222fdd40ca12f..1a10aa9a41b06 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -816,7 +816,7 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
 }
 
 // Create an expanded transpose op.
-// For bubbling a collapse : transpose(collapse_shape),
+// For sinking a collapse : transpose(collapse_shape),
 // all expanded groups are permuted together. We just permute the reassocation
 // map of the collapse and flatten it. For example,
 //
@@ -827,7 +827,7 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
 //
 // permutation = [4, 5, 0 , 1, 2, 3]
 //
-// For sinking expand : expand_shape(transpose),
+// For bubbling an expand : expand_shape(transpose),
 // the reassociation map is already permuted hence we inverse permute and then
 // flatten it. Then we inverse permute it again to get the final expanded
 // transpose permutation. For example,



More information about the Mlir-commits mailing list