[Mlir-commits] [mlir] [mlir][linalg] Enable expansion of parallel dims of reduction ops (PR #83473)

Quinn Dawkins llvmlistbot at llvm.org
Sat Mar 2 16:08:50 PST 2024


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/83473

>From 542787dcd06b8a091ce963e9d2f0deb7e4855035 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 29 Feb 2024 14:47:06 -0500
Subject: [PATCH 1/3] [mlir][linalg] Enable expansion of parallel dims of
 reduction ops

This adds support for expanding the parallel dimensions of linalg ops
that have reduction iterators. This improves the ability to make fusion
decisions with respect to reduction operations. To recover the previous
behavior, users of the patterns can add a control function that
restricts propagation of reshapes by expansion through linalg ops with
reduction iterators.
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 18 +++-
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 90 +++++++++++++++++++
 2 files changed, 104 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4797bfb2267d7f..6310f9105960be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -526,7 +526,10 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops are parallel loops.
+  // - All the loops for the reshaped operand are parallel loops.
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  AffineMap operandMap = genericOp.getMatchingIndexingMap(fusableOpOperand);
   return genericOp.hasPureTensorSemantics() &&
          llvm::all_of(genericOp.getIndexingMaps().getValue(),
                       [](Attribute attr) {
@@ -534,9 +537,11 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
-             0 &&
-         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
+         operandMap.getNumResults() > 0 &&
+         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
+           return isParallelIterator(
+               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
+         });
 }
 
 namespace {
@@ -848,6 +853,11 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // The iterator types of the expanded op are all parallel.
   SmallVector<utils::IteratorType> iteratorTypes(
       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+  for (auto [i, type] : llvm::enumerate(genericOp.getIteratorTypesArray())) {
+    ReassociationIndicesRef group = expansionInfo.getExpandedDims(i);
+    for (auto i : group)
+      iteratorTypes[i] = type;
+  }
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 0e40b5fbed97cb..5c0a83258b4b95 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -573,3 +573,93 @@ module {
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME:       outs(%[[ARG2]], %[[OUTS]] :
 //      CHECK:   return %[[GENERIC]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
+                                                        %arg1 : tensor<?x?xf32>,
+                                                        %arg2 : tensor<?x?xf32>) ->
+                                                        tensor<?x?x4x5xf32>
+{
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map2],
+     iterator_types = ["parallel", "parallel", "reduction"]}
+       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %s: f32):
+      %1 = arith.mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+//      CHECK: func @generic_op_reshape_consumer_fusion_reduction
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1, 2], [3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x?xf32>, tensor<?x4x5x?xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
+//      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x?x8xf32>,
+                                         %arg1 : tensor<?x4x?xf32>,
+                                         %arg2 : tensor<?x?xf32>) ->
+                                         tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map2],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+       ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x4x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+      %1 = arith.mulf %arg3, %arg4 : f32
+      %2 = arith.addf %1, %arg5 : f32
+      linalg.yield %2 : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d0, d1)>
+//  CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//  CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+//      CHECK: func @generic_op_reshape_producer_fusion_with_reduction
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1], [2], [3, 4]
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "reduction", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x8x4x?x7xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x8x?x7xf32>)
+//      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
+// CHECK-SAME:     [0, 1], [2, 3]
+// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
+//      CHECK:   return %[[T4]]

>From c0b6a7ac1d6a57dda17b737312f16fc660bc4e63 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 29 Feb 2024 16:21:18 -0500
Subject: [PATCH 2/3] Address comments

---
 .../lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6310f9105960be..402a7d58333a5c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -853,11 +853,9 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // The iterator types of the expanded op are all parallel.
   SmallVector<utils::IteratorType> iteratorTypes(
       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
-  for (auto [i, type] : llvm::enumerate(genericOp.getIteratorTypesArray())) {
-    ReassociationIndicesRef group = expansionInfo.getExpandedDims(i);
-    for (auto i : group)
-      iteratorTypes[i] = type;
-  }
+  for (auto [i, type] : llvm::enumerate(genericOp.getIteratorTypesArray()))
+    for (auto j : expansionInfo.getExpandedDims(i))
+      iteratorTypes[j] = type;
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =

>From bcdc0b810dd13979246810554e0d8b3ed8ddc06e Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sat, 2 Mar 2024 18:58:01 -0500
Subject: [PATCH 3/3] Make fusion by expansion work for linalg ops

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 98 +++++++++----------
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 71 ++++++++++++++
 2 files changed, 120 insertions(+), 49 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 402a7d58333a5c..9453502a253f16 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -467,10 +467,10 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 // expanding the dimensionality of the elementwise operations.
 //===---------------------------------------------------------------------===//
 
-/// Conditions for folding a generic operation with a reshape op by expanding
-/// the iteration space dimensionality for tensor operations. These are
-/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
-/// following fusion pattern.
+/// Conditions for folding a structured linalg operation with a reshape op by
+/// expanding the iteration space dimensionality for tensor operations. These
+/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
+/// the following fusion pattern.
 ///
 ///  Consider
 ///
@@ -481,9 +481,9 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 ///
-///  The reshape can be folded into the `genericOp` if its loop dimensionality
+///  The reshape can be folded into the `linalgOp` if its loop dimensionality
 ///  is increased to match the result (operand) of the tensor.expand_shape.
-///  The indexing_map of the fused tensor in the `genericOp` and the
+///  The indexing_map of the fused tensor in the `linalgOp` and the
 ///  reassociation map helps compute the indexing maps of the modified op.
 ///  For the above example, based on the reassociation map it
 ///  can be concluded that
@@ -502,7 +502,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///   d1 -> e2, e3, e4
 ///   d2 -> e5
 ///
-///  substituting this, the generic op can be rewritten as
+///  substituting this, the structured op can be rewritten as
 ///
 ///  %d = linalg.generic ins(%0, %1 : )
 ///        indexing_maps =
@@ -520,7 +520,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///
 ///  The added reshapes are again expanding patterns, so they will get fused
 ///  with its producers if possible.
-static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
+static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                                OpOperand *fusableOpOperand) {
   // Is fusable only if:
   // - All the indexing maps for operands and results are projected
@@ -528,10 +528,10 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
   // - The fused tensor is not a scalar.
   // - All the loops for the reshaped operand are parallel loops.
   SmallVector<utils::IteratorType> iteratorTypes =
-      genericOp.getIteratorTypesArray();
-  AffineMap operandMap = genericOp.getMatchingIndexingMap(fusableOpOperand);
-  return genericOp.hasPureTensorSemantics() &&
-         llvm::all_of(genericOp.getIndexingMaps().getValue(),
+      linalgOp.getIteratorTypesArray();
+  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
+  return linalgOp.hasPureTensorSemantics() &&
+         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                       [](Attribute attr) {
                         return cast<AffineMapAttr>(attr)
                             .getValue()
@@ -633,10 +633,10 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
 /// Note that this could be extended to handle dynamic case, but the
 /// implementation below uses `affine.apply` which seems to have issues when the
 /// shapes are not static.
-static LogicalResult isGenericOpExpandable(GenericOp genericOp,
-                                           const ExpansionInfo &expansionInfo,
-                                           PatternRewriter &rewriter) {
-  if (!genericOp.hasIndexSemantics())
+static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
+                                          const ExpansionInfo &expansionInfo,
+                                          PatternRewriter &rewriter) {
+  if (!linalgOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
@@ -645,7 +645,7 @@ static LogicalResult isGenericOpExpandable(GenericOp genericOp,
     for (int64_t shape : expandedShape.drop_front()) {
       if (ShapedType::isDynamic(shape)) {
         return rewriter.notifyMatchFailure(
-            genericOp, "cannot expand due to index semantics and dynamic dims");
+            linalgOp, "cannot expand due to index semantics and dynamic dims");
       }
     }
   }
@@ -754,10 +754,10 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
 /// that those conditions have been satisfied.
 static std::optional<SmallVector<Value>>
-fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
+fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                            OpOperand *fusableOpOperand,
                            PatternRewriter &rewriter) {
-  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
+  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
          "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
@@ -772,27 +772,27 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
 
   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
-          genericOp, fusableOpOperand,
+          linalgOp, fusableOpOperand,
           isExpanding ? expandingReshapeOp.getReassociationMaps()
                       : collapsingReshapeOp.getReassociationMaps(),
           expandedType.getShape(), collapsedType.getShape(), rewriter)))
     return std::nullopt;
 
-  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
+  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
     return std::nullopt;
 
   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
-      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
+      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
       }));
 
   // Set insertion point to the generic op.
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(genericOp);
+  rewriter.setInsertionPoint(linalgOp);
 
   SmallVector<Value> expandedOpOperands;
-  expandedOpOperands.reserve(genericOp.getNumDpsInputs());
-  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
+  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
+  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                                : collapsingReshapeOp.getSrc());
@@ -800,7 +800,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
-      AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
       RankedTensorType expandedOperandType =
           getExpandedType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
@@ -809,14 +809,14 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
             getReassociationForExpansion(indexingMap, expansionInfo);
         if (failed(reshapeLikeShapesAreCompatible(
                 [&](const Twine &msg) {
-                  return rewriter.notifyMatchFailure(genericOp, msg);
+                  return rewriter.notifyMatchFailure(linalgOp, msg);
                 },
                 opOperandType.getShape(), expandedOperandType.getShape(),
                 reassociation,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            genericOp.getLoc(), expandedOperandType, opOperand->get(),
+            linalgOp.getLoc(), expandedOperandType, opOperand->get(),
             reassociation));
         continue;
       }
@@ -824,10 +824,10 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
     expandedOpOperands.push_back(opOperand->get());
   }
 
-  Location loc = genericOp.getLoc();
+  Location loc = linalgOp.getLoc();
   SmallVector<Value> outputs;
-  for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) {
-    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
+  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
     RankedTensorType expandedOutputType =
         getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -836,14 +836,14 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
           getReassociationForExpansion(indexingMap, expansionInfo);
       if (failed(reshapeLikeShapesAreCompatible(
               [&](const Twine &msg) {
-                return rewriter.notifyMatchFailure(genericOp, msg);
+                return rewriter.notifyMatchFailure(linalgOp, msg);
               },
               opOperandType.getShape(), expandedOutputType.getShape(),
               reassociation,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          genericOp.getLoc(), expandedOutputType, opOperand.get(),
+          linalgOp.getLoc(), expandedOutputType, opOperand.get(),
           reassociation));
     } else {
       outputs.push_back(opOperand.get());
@@ -853,17 +853,17 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // The iterator types of the expanded op are all parallel.
   SmallVector<utils::IteratorType> iteratorTypes(
       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
-  for (auto [i, type] : llvm::enumerate(genericOp.getIteratorTypesArray()))
+  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
     for (auto j : expansionInfo.getExpandedDims(i))
       iteratorTypes[j] = type;
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =
-      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
+      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
                                  /*inputs=*/expandedOpOperands, outputs,
                                  expandedOpIndexingMaps, iteratorTypes);
   Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = genericOp->getRegion(0);
+  Region &originalRegion = linalgOp->getRegion(0);
   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
 
   // Update the index accesses after the expansion.
@@ -872,16 +872,16 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
-  for (OpResult opResult : genericOp->getOpResults()) {
+  for (OpResult opResult : linalgOp->getOpResults()) {
     int64_t resultNumber = opResult.getResultNumber();
     if (resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(
-              genericOp.getMatchingIndexingMap(
-                  genericOp.getDpsInitOperand(resultNumber)),
+              linalgOp.getMatchingIndexingMap(
+                  linalgOp.getDpsInitOperand(resultNumber)),
               expansionInfo);
       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
-          genericOp.getLoc(), opResult.getType(),
+          linalgOp.getLoc(), opResult.getType(),
           fusedOp->getResult(resultNumber), reassociation));
     } else {
       resultVals.push_back(fusedOp->getResult(resultNumber));
@@ -893,21 +893,21 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
 
 namespace {
 
-/// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
+/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
 /// in the consumer is expanded.
 class FoldWithProducerReshapeOpByExpansion
-    : public OpRewritePattern<GenericOp> {
+    : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlFoldingReshapes(std::move(foldReshapes)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
+    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
       tensor::CollapseShapeOp reshapeOp =
           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
       if (!reshapeOp)
@@ -915,15 +915,15 @@ class FoldWithProducerReshapeOpByExpansion
       // Fold only if
       // - The tensor reshape op is folding.
       // - All constraints of fusing with reshape by expansion are met.
-      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
+      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
           (!controlFoldingReshapes(opOperand)))
         continue;
 
       std::optional<SmallVector<Value>> replacementValues =
-          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
+          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
       if (!replacementValues)
         return failure();
-      rewriter.replaceOp(genericOp, *replacementValues);
+      rewriter.replaceOp(linalgOp, *replacementValues);
       return success();
     }
     return failure();
@@ -953,7 +953,7 @@ struct FoldReshapeWithGenericOpByExpansion
                                          "source not produced by an operation");
     }
 
-    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
+    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
     if (!producer) {
       return rewriter.notifyMatchFailure(reshapeOp,
                                          "producer not a generic op");
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 5c0a83258b4b95..342c067b5c4ba4 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -663,3 +663,74 @@ func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x
 // CHECK-SAME:     [0, 1], [2, 3]
 // CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
 //      CHECK:   return %[[T4]]
+
+// -----
+
+func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
+                                              %arg1 : tensor<?x?xf32>,
+                                              %arg2 : tensor<?x?xf32>) ->
+                                              tensor<?x?x4x5xf32>
+{
+  %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//      CHECK: func @linalg_add_reshape_consumer_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T3:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T4:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
+// CHECK-SAME:     outs(%[[T3]] : tensor<?x?x4x5xf32>)
+//      CHECK:   return %[[T4]] : tensor<?x?x4x5xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+                                              %arg1 : tensor<?x?xf32>,
+                                              %arg2 : tensor<?x?xf32>) ->
+                                              tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.add ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//  CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//      CHECK: func @linalg_add_reshape_producer_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x7x?x8xf32>)
+//      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
+// CHECK-SAME:     [0, 1], [2, 3]
+// CHECK-SAME:     tensor<?x7x?x8xf32> into tensor<?x?xf32>
+//      CHECK:   return %[[T4]]



More information about the Mlir-commits mailing list