[Mlir-commits] [mlir] [mlir][linalg] Retain Op Type of linalg ops in fuseWithReshapeByExpansion pattern (PR #129128)

Nirvedh Meshram llvmlistbot at llvm.org
Fri Feb 28 09:48:44 PST 2025


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/129128

From 620f2a80aecf8013dd242a2f7868510e3b8ed3da Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at nod-labs.com>
Date: Thu, 27 Feb 2025 12:55:44 -0600
Subject: [PATCH 1/2] [mlir][linalg] Retain named ops in
 fuseWithReshapeByExpansion pattern

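This teaches fuseWithReshapeByExpansion to keep the named form of a
linalg op instead of always rewriting it to linalg.generic. As an
illustrative sketch (mirroring the new transpose test in this patch),
fusing a collapsing reshape producer into a linalg.transpose:

  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
      : tensor<?x7x?x8xf32> into tensor<?x?xf32>
  %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
       outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]

now keeps the transpose on the expanded operands (%d0 and %d1 stand in
for the computed dynamic sizes):

  %e = tensor.expand_shape %arg1 [[0, 1], [2, 3]]
       output_shape [%d0, 8, %d1, 7]
      : tensor<?x?xf32> into tensor<?x8x?x7xf32>
  %t = linalg.transpose ins(%arg0 : tensor<?x7x?x8xf32>)
       outs(%e : tensor<?x8x?x7xf32>) permutation = [2, 3, 0, 1]
  %r = tensor.collapse_shape %t [[0, 1], [2, 3]]
      : tensor<?x8x?x7xf32> into tensor<?x?xf32>

The new permutation is obtained by permuting the reshape's
reassociation [[0, 1], [2, 3]] by the original permutation [1, 0],
giving [[2, 3], [0, 1]], and flattening it to [2, 3, 0, 1].
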
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 48 +++++++++++++-----
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 49 ++++++++++++++-----
 2 files changed, 75 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f4b6955823085..f64151db8e5a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -927,17 +927,43 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
       iteratorTypes[j] = type;
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
-  auto fusedOp =
-      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
-                                 /*inputs=*/expandedOpOperands, outputs,
-                                 expandedOpIndexingMaps, iteratorTypes);
-  Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = linalgOp->getRegion(0);
-  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
-
-  // Update the index accesses after the expansion.
-  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
-
+  Operation *fusedOp;
+
+  TypeSwitch<Operation *>(linalgOp.getOperation())
+      .Case<GenericOp>([&](GenericOp op) {
+        fusedOp = rewriter.create<GenericOp>(
+            linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
+            expandedOpIndexingMaps, iteratorTypes);
+        Region &fusedRegion = fusedOp->getRegion(0);
+        Region &originalRegion = linalgOp->getRegion(0);
+        rewriter.cloneRegionBefore(originalRegion, fusedRegion,
+                                   fusedRegion.begin());
+
+        // Update the index accesses after the expansion.
+        updateExpandedGenericOpRegion(rewriter, loc, fusedRegion,
+                                      expansionInfo);
+      })
+      .Case<TransposeOp>([&](TransposeOp op) {
+        SmallVector<ReassociationIndices> reassociation =
+            isExpanding ? expandingReshapeOp.getReassociationIndices()
+                        : collapsingReshapeOp.getReassociationIndices();
+        applyPermutationToVector(reassociation, op.getPermutation());
+        SmallVector<int64_t> newPerm;
+        for (auto reassoc : reassociation) {
+          for (auto dim : reassoc) {
+            newPerm.push_back(dim);
+          }
+        }
+        fusedOp = rewriter.create<TransposeOp>(
+            linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
+      })
+      // All other expandable linalg ops that are not generic or transpose can
+      // be cloned with the expanded input and output operands.
+      .Default([&](Operation *op) {
+        fusedOp = clone(
+            rewriter, linalgOp, resultTypes,
+            llvm::to_vector(llvm::concat<Value>(expandedOpOperands, outputs)));
+      });
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ef853e4d662a7..80cebab590f6f 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -753,7 +753,6 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
   return %1 : tensor<?x?x4x5xf32>
 }
 
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //      CHECK: func @linalg_add_reshape_consumer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -774,18 +773,13 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 //      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
 //      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
 //      CHECK:   %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-//      CHECK:   %[[T4:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
-// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+//      CHECK:   %[[T4:.+]] = linalg.add
 // CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
 // CHECK-SAME:     outs(%[[T3]] : tensor<?x?x4x5xf32>)
 //      CHECK:   return %[[T4]] : tensor<?x?x4x5xf32>
 
 // -----
 
-#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
 func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
                                               %arg1 : tensor<?x?xf32>,
                                               %arg2 : tensor<?x?xf32>) ->
@@ -798,7 +792,6 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
   return %1 : tensor<?x?xf32>
 }
 
-//  CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //      CHECK: func @linalg_add_reshape_producer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -817,9 +810,7 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 //      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
 //      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
 //      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
-//      CHECK:   %[[T3:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
-// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+//      CHECK:   %[[T3:.+]] = linalg.add
 // CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
 // CHECK-SAME:     outs(%[[T2]] : tensor<?x7x?x8xf32>)
 //      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
@@ -827,6 +818,42 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME:     tensor<?x7x?x8xf32> into tensor<?x?xf32>
 //      CHECK:   return %[[T4]]
 
+// -----
+
+func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+                                              %arg1 : tensor<?x?xf32>) ->
+                                              tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
+       outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
+  return %1 : tensor<?x?xf32>
+}
+
+//      CHECK: func @linalg_transpose_reshape_producer_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+//  CHECK-DAG:   %[[C7:.+]] = arith.constant 7 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//  CHECK-DAG:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//  CHECK-DAG:   %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
+//  CHECK-DAG:   %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
+//      CHECK:   %[[T2:.+]] = linalg.transpose
+// CHECK-SAME:     ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME:     outs(%[[T1]] : tensor<?x8x?x7xf32>)
+// CHECK-SAME:   permutation = [2, 3, 0, 1]
+//      CHECK:   %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME:     [0, 1], [2, 3]
+// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
+//      CHECK:   return %[[T3]]
+
+
+
 // -----
 
 func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {

From 8e07bc5d2f75afc0fb80b65464e597da7b107ffc Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Fri, 28 Feb 2025 11:37:17 -0600
Subject: [PATCH 2/2] Address reviewer comments

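This revision factors the op construction into a helper,
createFusedOpForReshapeByExpansion, and narrows the named-form
retention to linalg.transpose, linalg.fill, and linalg.copy; all other
ops (including linalg.add) are expanded to linalg.generic as before. A
copy fusion test is also added; as an illustrative sketch from that
test (%e stands in for the expand_shape of the original output), the
copy is retained on the expanded operands:

  %t = linalg.copy ins(%arg0 : tensor<?x7x?x8xf32>)
       outs(%e : tensor<?x7x?x8xf32>) -> tensor<?x7x?x8xf32>
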
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 99 ++++++++++---------
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 45 ++++++++-
 2 files changed, 96 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f64151db8e5a0..384f923f010c4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -814,6 +814,55 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
   }
   return success();
 }
+/// Create an expanded fused op that retains the name for certain ops
+/// such as fill, copy, and transpose, and produces a generic op for the
+/// rest of the linalg ops.
+Operation *createFusedOpForReshapeByExpansion(
+    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
+    ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
+    ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
+    SmallVector<ReassociationIndices> reassociation) {
+
+  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
+      .Case<TransposeOp>([&](TransposeOp op) {
+        applyPermutationToVector(reassociation, op.getPermutation());
+        SmallVector<int64_t> newPerm;
+        for (auto reassoc : reassociation) {
+          for (auto dim : reassoc) {
+            newPerm.push_back(dim);
+          }
+        }
+        return rewriter.create<TransposeOp>(
+            linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
+      })
+      .Case<FillOp, CopyOp>([&](Operation *op) {
+        return clone(rewriter, linalgOp, resultTypes,
+                     llvm::to_vector(llvm::concat<Value>(
+                         llvm::to_vector(expandedOpOperands),
+                         llvm::to_vector(outputs))));
+      })
+      .Default([&](Operation *op) {
+        // The iterator types of the expanded op are all parallel.
+        SmallVector<utils::IteratorType> iteratorTypes(
+            expansionInfo.getExpandedOpNumDims(),
+            utils::IteratorType::parallel);
+        for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
+          for (auto j : expansionInfo.getExpandedDims(i))
+            iteratorTypes[j] = type;
+        Operation *fused = rewriter.create<GenericOp>(
+            linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
+            expandedOpIndexingMaps, iteratorTypes);
+        Region &fusedRegion = fused->getRegion(0);
+        Region &originalRegion = linalgOp->getRegion(0);
+        rewriter.cloneRegionBefore(originalRegion, fusedRegion,
+                                   fusedRegion.begin());
+
+        // Update the index accesses after the expansion.
+        updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
+                                      expansionInfo);
+        return fused;
+      });
+}
 
 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
@@ -919,51 +968,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
     }
   }
 
-  // The iterator types of the expanded op are all parallel.
-  SmallVector<utils::IteratorType> iteratorTypes(
-      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
-  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
-    for (auto j : expansionInfo.getExpandedDims(i))
-      iteratorTypes[j] = type;
-
   TypeRange resultTypes = ValueRange(outputs).getTypes();
-  Operation *fusedOp;
-
-  TypeSwitch<Operation *>(linalgOp.getOperation())
-      .Case<GenericOp>([&](GenericOp op) {
-        fusedOp = rewriter.create<GenericOp>(
-            linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
-            expandedOpIndexingMaps, iteratorTypes);
-        Region &fusedRegion = fusedOp->getRegion(0);
-        Region &originalRegion = linalgOp->getRegion(0);
-        rewriter.cloneRegionBefore(originalRegion, fusedRegion,
-                                   fusedRegion.begin());
-
-        // Update the index accesses after the expansion.
-        updateExpandedGenericOpRegion(rewriter, loc, fusedRegion,
-                                      expansionInfo);
-      })
-      .Case<TransposeOp>([&](TransposeOp op) {
-        SmallVector<ReassociationIndices> reassociation =
-            isExpanding ? expandingReshapeOp.getReassociationIndices()
-                        : collapsingReshapeOp.getReassociationIndices();
-        applyPermutationToVector(reassociation, op.getPermutation());
-        SmallVector<int64_t> newPerm;
-        for (auto reassoc : reassociation) {
-          for (auto dim : reassoc) {
-            newPerm.push_back(dim);
-          }
-        }
-        fusedOp = rewriter.create<TransposeOp>(
-            linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
-      })
-      // All other expandable linalg ops that are not generic or transpose can
-      // be cloned with the expanded input and output operands.
-      .Default([&](Operation *op) {
-        fusedOp = clone(
-            rewriter, linalgOp, resultTypes,
-            llvm::to_vector(llvm::concat<Value>(expandedOpOperands, outputs)));
-      });
+  SmallVector<ReassociationIndices> reassociationBeforeExpansion =
+      isExpanding ? expandingReshapeOp.getReassociationIndices()
+                  : collapsingReshapeOp.getReassociationIndices();
+  Operation *fusedOp = createFusedOpForReshapeByExpansion(
+      rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
+      expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 80cebab590f6f..dab15615459f8 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -753,6 +753,7 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
   return %1 : tensor<?x?x4x5xf32>
 }
 
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //      CHECK: func @linalg_add_reshape_consumer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -773,7 +774,9 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 //      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
 //      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
 //      CHECK:   %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-//      CHECK:   %[[T4:.+]] = linalg.add
+//      CHECK:   %[[T4:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
 // CHECK-SAME:     outs(%[[T3]] : tensor<?x?x4x5xf32>)
 //      CHECK:   return %[[T4]] : tensor<?x?x4x5xf32>
@@ -792,6 +795,7 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
   return %1 : tensor<?x?xf32>
 }
 
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //      CHECK: func @linalg_add_reshape_producer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -810,7 +814,9 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 //      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
 //      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
 //      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
-//      CHECK:   %[[T3:.+]] = linalg.add
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
 // CHECK-SAME:     outs(%[[T2]] : tensor<?x7x?x8xf32>)
 //      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
@@ -820,6 +826,39 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 
 // -----
 
+func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+                                              %arg1 : tensor<?x?xf32>) ->
+                                              tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.copy ins(%0 : tensor<?x?xf32>)
+       outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//      CHECK: func @linalg_copy_reshape_producer_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
+//      CHECK:   %[[C7:.+]] = arith.constant 7 : index
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
+//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+//      CHECK:   %[[T2:.+]] = linalg.copy
+// CHECK-SAME:     ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME:     outs(%[[T1]] : tensor<?x7x?x8xf32>)
+//      CHECK:   %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME:     [0, 1], [2, 3]
+// CHECK-SAME:     tensor<?x7x?x8xf32> into tensor<?x?xf32>
+//      CHECK:   return %[[T3]]
+
+// -----
+
 func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
                                               %arg1 : tensor<?x?xf32>) ->
                                               tensor<?x?xf32>
@@ -852,8 +891,6 @@ func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
 //      CHECK:   return %[[T3]]
 
-
-
 // -----
 
 func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {


