[Mlir-commits] [mlir] [mlir][linalg] Add transpose support for reshape as consumer fusion (PR #130344)

Nirvedh Meshram llvmlistbot at llvm.org
Mon Mar 10 19:27:59 PDT 2025


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/130344

>From 54821d6e7c6e9e511f70d4c83c65e37104d93593 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Thu, 6 Mar 2025 21:52:59 -0600
Subject: [PATCH 1/3] [mlir][linalg] Add transpose support for reshape as
 consumer fusion

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 50 +++++++++++--
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 75 +++++++++++--------
 2 files changed, 88 insertions(+), 37 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3f016fed3519c..6b48d6c587d81 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -811,17 +811,51 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
 }
 
 // Create an expanded transpose op.
+// For bubbling a collapse : transpose(collapse_shape),
+// all expanded groups are permuted together. We just permute the reassocation
+// map of the collapse and flatten it. For example,
+//
+// reassociation_map = [[0], [1, 2, 3], [4, 5]]
+// permutation = [2, 0, 1]
+//
+// Becomes
+//
+// permutation = [4, 5, 0 , 1, 2, 3]
+//
+// For sinking expand : expand_shape(transpose),
+// the reassociation map is already permuted hence we inverse permute and then
+// flatten it. Then we inverse permute it again to get the final expanded
+// transpose permutation. For example,
+//
+// permutation = [2, 0, 1]
+// reassociation_map = [[0, 1], [2], [3, 4, 5]]
+//
+// inverse permutation = [1, 2, 0]
+// applied to reassocation_map and then flattened becomes
+// flattened permutation = [2, 3, 4, 5, 0, 1]
+// final permutation is the inverse of the flattened permutation.
+//
+// Becomes
+//
+// permutation=[4, 5, 0, 1, 2, 3]
+
 static Operation *
 createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
                           SmallVector<ReassociationIndices> reassociation,
-                          Value expandedInput, Value output) {
-  applyPermutationToVector(reassociation, transposeOp.getPermutation());
+                          Value expandedInput, Value output, bool isExpanding) {
+  ArrayRef<int64_t> permutation =
+      isExpanding ? invertPermutationVector(transposeOp.getPermutation())
+                  : transposeOp.getPermutation();
+  applyPermutationToVector(reassociation, permutation);
   SmallVector<int64_t> newPerm;
   for (const auto &reassoc : reassociation) {
     for (auto dim : reassoc) {
       newPerm.push_back(dim);
     }
   }
+  if (isExpanding) {
+    newPerm = invertPermutationVector(newPerm);
+  }
   return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
                                       output, newPerm);
 }
@@ -861,12 +895,13 @@ static Operation *createExpandedOp(
     PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
     ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
     ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
-    SmallVector<ReassociationIndices> reassociation) {
+    SmallVector<ReassociationIndices> reassociation, bool isExpanding) {
 
   return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
       .Case<TransposeOp>([&](TransposeOp transposeOp) {
         return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
-                                         expandedOpOperands[0], outputs[0]);
+                                         expandedOpOperands[0], outputs[0],
+                                         isExpanding);
       })
       .Case<FillOp, CopyOp>([&](Operation *op) {
         return clone(rewriter, linalgOp, resultTypes,
@@ -989,9 +1024,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   SmallVector<ReassociationIndices> reassociationBeforeExpansion =
       isExpanding ? expandingReshapeOp.getReassociationIndices()
                   : collapsingReshapeOp.getReassociationIndices();
-  Operation *fusedOp = createExpandedOp(
-      rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
-      expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
+  Operation *fusedOp =
+      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
+                       outputs, expandedOpIndexingMaps, expansionInfo,
+                       reassociationBeforeExpansion, isExpanding);
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index c8720ebd98c09..3244418d445b7 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -195,7 +195,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 // CHECK-SAME:     : tensor<8x33x4xf32>
 //  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
 //      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
 //      CHECK:   %[[T2:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel"]
@@ -203,6 +203,29 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 // CHECK-SAME:     outs(%[[T1]] : tensor<8x33x4xf32>)
 //      CHECK:   return %[[T2]] : tensor<8x33x4xf32>
 
+// -----
+
+func.func @reshape_as_consumer_transpose
+  (%a :  tensor<4x210x6xf32>)
+    -> tensor<2x3x4x5x6x7xf32> {
+  %b = tensor.empty() : tensor<6x4x210xf32>
+  %c = linalg.transpose
+          ins(%a : tensor<4x210x6xf32>)
+         outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+  return %d : tensor<2x3x4x5x6x7xf32>
+}
+//      CHECK: func @reshape_as_consumer_transpose
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32>
+//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
+//  CHECK-DAG:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32>
+//  CHECK-DAG:   %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+//      CHECK:   %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME:     outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME:     permutation = [4, 5, 0, 1, 2, 3]
+//      CHECK:   return %[[T2]] : tensor<2x3x4x5x6x7xf32>
+
+
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
@@ -884,37 +907,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 
 // -----
 
-func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
-                                              %arg1 : tensor<?x?xf32>) ->
-                                              tensor<?x?xf32>
-{
-  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
-    tensor<?x7x?x8xf32> into tensor<?x?xf32>
-  %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
-       outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
-  return %1 : tensor<?x?xf32>
+
+func.func @reshape_as_producer_transpose
+  (%a :  tensor<4x5x6x7x2x3xf32>)
+    -> tensor<6x4x210xf32> {
+  %b = tensor.empty() : tensor<6x4x210xf32>
+  %c = tensor.collapse_shape %a [[0], [1, 2, 3], [4, 5]] :
+    tensor<4x5x6x7x2x3xf32> into tensor<4x210x6xf32>
+  %d = linalg.transpose
+          ins(%c : tensor<4x210x6xf32>)
+         outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+  return %d : tensor<6x4x210xf32>
 }
 
-//      CHECK: func @linalg_transpose_reshape_producer_fusion
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//  CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
-//  CHECK-DAG:   %[[C7:.+]] = arith.constant 7 : index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//  CHECK-DAG:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//  CHECK-DAG:   %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
-//  CHECK-DAG:   %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
-//      CHECK:   %[[T2:.+]] = linalg.transpose
-// CHECK-SAME:     ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
-// CHECK-SAME:     outs(%[[T1]] : tensor<?x8x?x7xf32>)
-// CHECK-SAME:   permutation = [2, 3, 0, 1]
-//      CHECK:   %[[T3:.+]] = tensor.collapse_shape %[[T2]]
-// CHECK-SAME:     [0, 1], [2, 3]
-// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
-//      CHECK:   return %[[T3]]
+//      CHECK: func @reshape_as_producer_transpose
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32>
+//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
+//  CHECK-DAG:   %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+//      CHECK:   %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME:     outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME:     permutation = [4, 5, 0, 1, 2, 3]
+//      CHECK:   %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32>
+//      CHECK:   return %[[T2]] : tensor<6x4x210xf32>
+
 
 // -----
 

>From d8c467e282f4fa300cafc10c7af6b06fde6f857c Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 10 Mar 2025 20:02:39 -0500
Subject: [PATCH 2/3] fix sinking expand naming in comment

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6b48d6c587d81..03ec87400185a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -811,7 +811,7 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
 }
 
 // Create an expanded transpose op.
-// For bubbling a collapse : transpose(collapse_shape),
+// For sinking a collapse : transpose(collapse_shape),
 // all expanded groups are permuted together. We just permute the reassocation
 // map of the collapse and flatten it. For example,
 //
@@ -822,7 +822,7 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
 //
 // permutation = [4, 5, 0 , 1, 2, 3]
 //
-// For sinking expand : expand_shape(transpose),
+// For bubbling an expand : expand_shape(transpose),
 // the reassociation map is already permuted hence we inverse permute and then
 // flatten it. Then we inverse permute it again to get the final expanded
 // transpose permutation. For example,

>From 75d238a5b1fc7f606db1423d766a5985801411fc Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 10 Mar 2025 21:16:19 -0500
Subject: [PATCH 3/3] Use expansioninfo to get output reassociation

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 57 ++++++-------------
 1 file changed, 18 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 03ec87400185a..33667e7ab0c5c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -811,24 +811,12 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
 }
 
 // Create an expanded transpose op.
-// For sinking a collapse : transpose(collapse_shape),
-// all expanded groups are permuted together. We just permute the reassocation
-// map of the collapse and flatten it. For example,
-//
-// reassociation_map = [[0], [1, 2, 3], [4, 5]]
-// permutation = [2, 0, 1]
-//
-// Becomes
-//
-// permutation = [4, 5, 0 , 1, 2, 3]
-//
-// For bubbling an expand : expand_shape(transpose),
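-// the reassociation map is already permuted hence we inverse permute and then
-// flatten it. Then we inverse permute it again to get the final expanded
-// transpose permutation. For example,
+// The expanded-dim groups (one per output dimension of the original transpose)
+// are reordered by the inverse permutation and flattened; inverting that
+// flattened list gives the final expanded transpose permutation. For example,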
 //
 // permutation = [2, 0, 1]
-// reassociation_map = [[0, 1], [2], [3, 4, 5]]
+// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
 //
 // inverse permutation = [1, 2, 0]
-// applied to reassocation_map and then flattened becomes
+// applied to reassociation_map and then flattened becomes
@@ -839,25 +827,19 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
 //
 // permutation=[4, 5, 0, 1, 2, 3]
 
-static Operation *
-createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
-                          SmallVector<ReassociationIndices> reassociation,
-                          Value expandedInput, Value output, bool isExpanding) {
-  ArrayRef<int64_t> permutation =
-      isExpanding ? invertPermutationVector(transposeOp.getPermutation())
-                  : transposeOp.getPermutation();
-  applyPermutationToVector(reassociation, permutation);
+static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
+                                            TransposeOp transposeOp,
+                                            Value expandedInput, Value output,
+                                            ExpansionInfo &expansionInfo) {
   SmallVector<int64_t> newPerm;
-  for (const auto &reassoc : reassociation) {
-    for (auto dim : reassoc) {
+  for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
+    auto reassoc = expansionInfo.getExpandedDims(perm);
+    for (int64_t dim : reassoc) {
       newPerm.push_back(dim);
     }
   }
-  if (isExpanding) {
-    newPerm = invertPermutationVector(newPerm);
-  }
   return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
-                                      output, newPerm);
+                                      output, invertPermutationVector(newPerm));
 }
 
 // Create an expanded generic op.
@@ -891,17 +873,18 @@ static Operation *createExpandedGenericOp(
 // Create an expanded fused op that retains the name for certain ops
 // such as fill, copy and transpose and produce a generic op for
 // rest of linalg ops.
-static Operation *createExpandedOp(
-    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
-    ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
-    ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
-    SmallVector<ReassociationIndices> reassociation, bool isExpanding) {
+static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
+                                   TypeRange resultTypes,
+                                   ArrayRef<Value> expandedOpOperands,
+                                   ArrayRef<Value> outputs,
+                                   ArrayRef<AffineMap> expandedOpIndexingMaps,
+                                   ExpansionInfo &expansionInfo) {
 
   return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
       .Case<TransposeOp>([&](TransposeOp transposeOp) {
-        return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
+        return createExpandedTransposeOp(rewriter, transposeOp,
                                          expandedOpOperands[0], outputs[0],
-                                         isExpanding);
+                                         expansionInfo);
       })
       .Case<FillOp, CopyOp>([&](Operation *op) {
         return clone(rewriter, linalgOp, resultTypes,
@@ -1021,13 +1004,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   }
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
-  SmallVector<ReassociationIndices> reassociationBeforeExpansion =
-      isExpanding ? expandingReshapeOp.getReassociationIndices()
-                  : collapsingReshapeOp.getReassociationIndices();
   Operation *fusedOp =
       createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
-                       outputs, expandedOpIndexingMaps, expansionInfo,
-                       reassociationBeforeExpansion, isExpanding);
+                       outputs, expandedOpIndexingMaps, expansionInfo);
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;



More information about the Mlir-commits mailing list