[Mlir-commits] [mlir] 3f18f6a - [mlir][linalg] Enable fusion by expansion of reduction and named ops (#83473)

llvmlistbot at llvm.org
Sat Mar 2 22:54:07 PST 2024


Author: Quinn Dawkins
Date: 2024-03-03T01:54:03-05:00
New Revision: 3f18f6a2cfecb080f006477c46d3626102841a17

URL: https://github.com/llvm/llvm-project/commit/3f18f6a2cfecb080f006477c46d3626102841a17
DIFF: https://github.com/llvm/llvm-project/commit/3f18f6a2cfecb080f006477c46d3626102841a17.diff

LOG: [mlir][linalg] Enable fusion by expansion of reduction and named ops (#83473)

This adds support for expansion of named linalg ops and of linalg ops with
reduction iterators. This improves the ability to make fusion decisions
with respect to reduction operations. To recover the previous behavior,
users of the patterns can supply a control function that restricts
propagation of reshapes by expansion through linalg ops with reduction
iterators.

For named linalg ops, fusion always converts the named op into a linalg.generic.
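
A minimal sketch of such a control function (not part of this commit;
linalg::ControlFusionFn, LinalgOp, and the iterator utilities are existing
upstream APIs, while the lambda body itself is illustrative):

    #include "mlir/Dialect/Linalg/IR/Linalg.h"
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Dialect/Linalg/Utils/Utils.h"

    using namespace mlir;

    // Reject reshape propagation by expansion through linalg ops carrying
    // reduction iterators, recovering the pre-patch, parallel-only behavior.
    linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
      // Consumer fusion passes an operand of the linalg op itself; producer
      // fusion passes the reshape's source operand, so inspect both sides.
      Operation *producer = fusedOperand->get().getDefiningOp();
      Operation *consumer = fusedOperand->getOwner();
      for (Operation *op : {producer, consumer}) {
        auto linalgOp = dyn_cast_or_null<linalg::LinalgOp>(op);
        if (linalgOp && !llvm::all_of(linalgOp.getIteratorTypesArray(),
                                      linalg::isParallelIterator))
          return false;
      }
      return true;
    };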

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/reshape_fusion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4797bfb2267d7f..9453502a253f16 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -467,10 +467,10 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 // expanding the dimensionality of the elementwise operations.
 //===---------------------------------------------------------------------===//
 
-/// Conditions for folding a generic operation with a reshape op by expanding
-/// the iteration space dimensionality for tensor operations. These are
-/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
-/// following fusion pattern.
+/// Conditions for folding a structured linalg operation with a reshape op by
+/// expanding the iteration space dimensionality for tensor operations. These
+/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
+/// the following fusion pattern.
 ///
 ///  Consider
 ///
@@ -481,9 +481,9 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 ///
-///  The reshape can be folded into the `genericOp` if its loop dimensionality
+///  The reshape can be folded into the `linalgOp` if its loop dimensionality
 ///  is increased to match the result (operand) of the tensor.expand_shape.
-///  The indexing_map of the fused tensor in the `genericOp` and the
+///  The indexing_map of the fused tensor in the `linalgOp` and the
 ///  reassociation map helps compute the indexing maps of the modified op.
 ///  For the above example, based on the reassociation map it
 ///  can be concluded that
@@ -502,7 +502,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///   d1 -> e2, e3, e4
 ///   d2 -> e5
 ///
-///  substituting this, the generic op can be rewritten as
+///  substituting this, the structured op can be rewritten as
 ///
 ///  %d = linalg.generic ins(%0, %1 : )
 ///        indexing_maps =
@@ -520,23 +520,28 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///
 ///  The added reshapes are again expanding patterns, so they will get fused
 ///  with their producers if possible.
-static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
+static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                                OpOperand *fusableOpOperand) {
   // Is fusable only if:
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops are parallel loops.
-  return genericOp.hasPureTensorSemantics() &&
-         llvm::all_of(genericOp.getIndexingMaps().getValue(),
+  // - All the loops for the reshaped operand are parallel loops.
+  SmallVector<utils::IteratorType> iteratorTypes =
+      linalgOp.getIteratorTypesArray();
+  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
+  return linalgOp.hasPureTensorSemantics() &&
+         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                       [](Attribute attr) {
                         return cast<AffineMapAttr>(attr)
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
-             0 &&
-         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
+         operandMap.getNumResults() > 0 &&
+         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
+           return isParallelIterator(
+               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
+         });
 }
 
 namespace {
@@ -628,10 +633,10 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
 /// Note that this could be extended to handle the dynamic case, but the
 /// implementation below uses `affine.apply` which seems to have issues when the
 /// shapes are not static.
-static LogicalResult isGenericOpExpandable(GenericOp genericOp,
-                                           const ExpansionInfo &expansionInfo,
-                                           PatternRewriter &rewriter) {
-  if (!genericOp.hasIndexSemantics())
+static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
+                                          const ExpansionInfo &expansionInfo,
+                                          PatternRewriter &rewriter) {
+  if (!linalgOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
@@ -640,7 +645,7 @@ static LogicalResult isGenericOpExpandable(GenericOp genericOp,
     for (int64_t shape : expandedShape.drop_front()) {
       if (ShapedType::isDynamic(shape)) {
         return rewriter.notifyMatchFailure(
-            genericOp, "cannot expand due to index semantics and dynamic dims");
+            linalgOp, "cannot expand due to index semantics and dynamic dims");
       }
     }
   }
@@ -749,10 +754,10 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
 /// and a generic op as explained in `isFusableWithReshapeByDimExpansion`. Assumes
 /// that those conditions have been satisfied.
 static std::optional<SmallVector<Value>>
-fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
+fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                            OpOperand *fusableOpOperand,
                            PatternRewriter &rewriter) {
-  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
+  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
          "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
@@ -767,27 +772,27 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
 
   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
-          genericOp, fusableOpOperand,
+          linalgOp, fusableOpOperand,
           isExpanding ? expandingReshapeOp.getReassociationMaps()
                       : collapsingReshapeOp.getReassociationMaps(),
           expandedType.getShape(), collapsedType.getShape(), rewriter)))
     return std::nullopt;
 
-  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
+  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
     return std::nullopt;
 
   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
-      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
+      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
       }));
 
   // Set insertion point to the generic op.
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(genericOp);
+  rewriter.setInsertionPoint(linalgOp);
 
   SmallVector<Value> expandedOpOperands;
-  expandedOpOperands.reserve(genericOp.getNumDpsInputs());
-  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
+  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
+  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                                : collapsingReshapeOp.getSrc());
@@ -795,7 +800,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
-      AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
       RankedTensorType expandedOperandType =
           getExpandedType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
@@ -804,14 +809,14 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
             getReassociationForExpansion(indexingMap, expansionInfo);
         if (failed(reshapeLikeShapesAreCompatible(
                 [&](const Twine &msg) {
-                  return rewriter.notifyMatchFailure(genericOp, msg);
+                  return rewriter.notifyMatchFailure(linalgOp, msg);
                 },
                 opOperandType.getShape(), expandedOperandType.getShape(),
                 reassociation,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            genericOp.getLoc(), expandedOperandType, opOperand->get(),
+            linalgOp.getLoc(), expandedOperandType, opOperand->get(),
             reassociation));
         continue;
       }
@@ -819,10 +824,10 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
     expandedOpOperands.push_back(opOperand->get());
   }
 
-  Location loc = genericOp.getLoc();
+  Location loc = linalgOp.getLoc();
   SmallVector<Value> outputs;
-  for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) {
-    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
+  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
     RankedTensorType expandedOutputType =
         getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -831,14 +836,14 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
           getReassociationForExpansion(indexingMap, expansionInfo);
       if (failed(reshapeLikeShapesAreCompatible(
               [&](const Twine &msg) {
-                return rewriter.notifyMatchFailure(genericOp, msg);
+                return rewriter.notifyMatchFailure(linalgOp, msg);
               },
               opOperandType.getShape(), expandedOutputType.getShape(),
               reassociation,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          genericOp.getLoc(), expandedOutputType, opOperand.get(),
+          linalgOp.getLoc(), expandedOutputType, opOperand.get(),
           reassociation));
     } else {
       outputs.push_back(opOperand.get());
@@ -848,14 +853,17 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // Start from all-parallel iterator types; the original op's iterator types are then copied to the corresponding expanded dimensions below.
   SmallVector<utils::IteratorType> iteratorTypes(
       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
+    for (auto j : expansionInfo.getExpandedDims(i))
+      iteratorTypes[j] = type;
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =
-      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
+      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
                                  /*inputs=*/expandedOpOperands, outputs,
                                  expandedOpIndexingMaps, iteratorTypes);
   Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = genericOp->getRegion(0);
+  Region &originalRegion = linalgOp->getRegion(0);
   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
 
   // Update the index accesses after the expansion.
@@ -864,16 +872,16 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
-  for (OpResult opResult : genericOp->getOpResults()) {
+  for (OpResult opResult : linalgOp->getOpResults()) {
     int64_t resultNumber = opResult.getResultNumber();
     if (resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(
-              genericOp.getMatchingIndexingMap(
-                  genericOp.getDpsInitOperand(resultNumber)),
+              linalgOp.getMatchingIndexingMap(
+                  linalgOp.getDpsInitOperand(resultNumber)),
               expansionInfo);
       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
-          genericOp.getLoc(), opResult.getType(),
+          linalgOp.getLoc(), opResult.getType(),
           fusedOp->getResult(resultNumber), reassociation));
     } else {
       resultVals.push_back(fusedOp->getResult(resultNumber));
@@ -885,21 +893,21 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
 
 namespace {
 
-/// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
+/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
 /// in the consumer is expanded.
 class FoldWithProducerReshapeOpByExpansion
-    : public OpRewritePattern<GenericOp> {
+    : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlFoldingReshapes(std::move(foldReshapes)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
+    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
       tensor::CollapseShapeOp reshapeOp =
           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
       if (!reshapeOp)
@@ -907,15 +915,15 @@ class FoldWithProducerReshapeOpByExpansion
       // Fold only if
       // - The tensor reshape op is folding.
       // - All constraints of fusing with reshape by expansion are met.
-      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
+      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
           (!controlFoldingReshapes(opOperand)))
         continue;
 
       std::optional<SmallVector<Value>> replacementValues =
-          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
+          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
       if (!replacementValues)
         return failure();
-      rewriter.replaceOp(genericOp, *replacementValues);
+      rewriter.replaceOp(linalgOp, *replacementValues);
       return success();
     }
     return failure();
@@ -945,7 +953,7 @@ struct FoldReshapeWithGenericOpByExpansion
                                          "source not produced by an operation");
     }
 
-    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
+    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
     if (!producer) {
       return rewriter.notifyMatchFailure(reshapeOp,
                                          "producer not a generic op");

diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 0e40b5fbed97cb..342c067b5c4ba4 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -573,3 +573,164 @@ module {
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME:       outs(%[[ARG2]], %[[OUTS]] :
 //      CHECK:   return %[[GENERIC]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
+                                                        %arg1 : tensor<?x?xf32>,
+                                                        %arg2 : tensor<?x?xf32>) ->
+                                                        tensor<?x?x4x5xf32>
+{
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map2],
+     iterator_types = ["parallel", "parallel", "reduction"]}
+       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %s: f32):
+      %1 = arith.mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+//      CHECK: func @generic_op_reshape_consumer_fusion_reduction
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1, 2], [3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x?xf32>, tensor<?x4x5x?xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
+//      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x?x8xf32>,
+                                         %arg1 : tensor<?x4x?xf32>,
+                                         %arg2 : tensor<?x?xf32>) ->
+                                         tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map2],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+       ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x4x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+      %1 = arith.mulf %arg3, %arg4 : f32
+      %2 = arith.addf %1, %arg5 : f32
+      linalg.yield %2 : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d0, d1)>
+//  CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//  CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+//      CHECK: func @generic_op_reshape_producer_fusion_with_reduction
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1], [2], [3, 4]
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "reduction", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x8x4x?x7xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x8x?x7xf32>)
+//      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
+// CHECK-SAME:     [0, 1], [2, 3]
+// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
+//      CHECK:   return %[[T4]]
+
+// -----
+
+func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
+                                              %arg1 : tensor<?x?xf32>,
+                                              %arg2 : tensor<?x?xf32>) ->
+                                              tensor<?x?x4x5xf32>
+{
+  %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//      CHECK: func @linalg_add_reshape_consumer_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T3:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T4:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
+// CHECK-SAME:     outs(%[[T3]] : tensor<?x?x4x5xf32>)
+//      CHECK:   return %[[T4]] : tensor<?x?x4x5xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+                                              %arg1 : tensor<?x?xf32>,
+                                              %arg2 : tensor<?x?xf32>) ->
+                                              tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.add ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//  CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//      CHECK: func @linalg_add_reshape_producer_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x7x?x8xf32>)
+//      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
+// CHECK-SAME:     [0, 1], [2, 3]
+// CHECK-SAME:     tensor<?x7x?x8xf32> into tensor<?x?xf32>
+//      CHECK:   return %[[T4]]
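
For reference, a minimal sketch of how these expansion patterns are usually
driven from a pass; populateFoldReshapeOpsByExpansionPatterns and
applyPatternsAndFoldGreedily are the existing upstream entry points, while
funcOp and the controlFn sketched above are assumed context:

    RewritePatternSet patterns(funcOp.getContext());
    linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
    // Greedily applies FoldWithProducerReshapeOpByExpansion and
    // FoldReshapeWithGenericOpByExpansion until a fixed point is reached.
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      signalPassFailure();

The reshape_fusion.mlir tests above exercise the same patterns through the
test pass named in that file's RUN line.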

