[Mlir-commits] [mlir] 2c58cde - [mlir][Linalg] Add pattern for folding reshape by collapsing.

Mahesh Ravishankar llvmlistbot at llvm.org
Tue Feb 15 19:15:48 PST 2022


Author: Mahesh Ravishankar
Date: 2022-02-16T03:15:20Z
New Revision: 2c58cde003eb7b5aaae3eb5ac94f25a52f151df2

URL: https://github.com/llvm/llvm-project/commit/2c58cde003eb7b5aaae3eb5ac94f25a52f151df2
DIFF: https://github.com/llvm/llvm-project/commit/2c58cde003eb7b5aaae3eb5ac94f25a52f151df2.diff

LOG: [mlir][Linalg] Add pattern for folding reshape by collapsing.

Fusion of `linalg.generic` with
`tensor.expand_shape`/`tensor.collapse_shape` currently handles fusion
with a reshape by expanding the dimensionality of the `linalg.generic`
operation. This helps fuse elementwise operations better, since they are
fused at the highest dimensionality while keeping all the indexing maps
involved projected permutations. The intent of these patterns is to push
the reshapes to the boundaries of functions.

The presence of named ops (or other ops across which the reshape cannot
be propagated) stops the propagation before it reaches the edges of the
function. At that point, the converse patterns, which fold reshapes into
generic ops by collapsing the dimensions of the generic op, can continue
pushing the reshapes towards the edges. In particular, this helps the
case where a reshape sits between a named op and a generic op:

`linalg.named_op` -> `tensor.expand_shape` -> `linalg.generic`

Pushing the reshape down helps fuse `linalg.named_op` ->
`linalg.generic` using tile-and-fuse transformations.
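
For illustration, here is a simplified sketch of the rewrite (value
names and shapes are hypothetical; the added test file
fuse-with-reshape-by-collapsing.mlir has the real cases):

  // Before: the expand_shape between the producer and the elementwise
  // generic blocks producer-consumer fusion.
  %expanded = tensor.expand_shape %producer [[0, 1]]
      : tensor<6xf32> into tensor<2x3xf32>
  %init = linalg.init_tensor [2, 3] : tensor<2x3xf32>
  %result = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%expanded : tensor<2x3xf32>) outs(%init : tensor<2x3xf32>) {
        ^bb0(%b0 : f32, %b1 : f32):
          %0 = arith.addf %b0, %b0 : f32
          linalg.yield %0 : f32
      } -> tensor<2x3xf32>

  // After: the generic runs on the collapsed iteration space and the
  // expand_shape is pushed below it, next to the uses.
  %init_collapsed = tensor.collapse_shape %init [[0, 1]]
      : tensor<2x3xf32> into tensor<6xf32>
  %collapsed = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%producer : tensor<6xf32>) outs(%init_collapsed : tensor<6xf32>) {
        ^bb0(%b0 : f32, %b1 : f32):
          %0 = arith.addf %b0, %b0 : f32
          linalg.yield %0 : f32
      } -> tensor<6xf32>
  %result = tensor.expand_shape %collapsed [[0, 1]]
      : tensor<6xf32> into tensor<2x3xf32>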

This pattern is intended to replace the following patterns:

1) FoldReshapeByLinearization: These patterns create indexing maps
that are not projected permutations, which hampers future
transformations. They are only useful for folding unit dimensions.
2) PushReshapeByExpansion: This pattern has the same functionality
but with some restrictions:
    a) It tries to avoid creating new reshapes, which limits its
    applicability. The pattern added here can achieve the same
    functionality through the `controlFn`, which gives clients of the
    pattern the freedom to make this decision.
    b) It does not work for ops with indexing semantics; the pattern
    added here does (see the sketch below).

These patterns will be deprecated in a future patch.
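
Unlike PushReshapeByExpansion, the new pattern also handles ops with
index semantics: in the collapsed op, the original `linalg.index` values
are recovered by delinearizing the collapsed induction variable with
arith.divui/arith.remui. A minimal sketch of the rewritten region body,
assuming two folded loops of sizes 3 and 4 were collapsed into a single
loop of size 12 (names are hypothetical):

  %folded = linalg.index 1 : index
  %c4 = arith.constant 4 : index
  // Index of the inner (size-4) folded loop.
  %inner = arith.remui %folded, %c4 : index
  // Index of the outer (size-3) folded loop.
  %outer = arith.divui %folded, %c4 : index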

Differential Revision: https://reviews.llvm.org/D119365

Added: 
    mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 00de7ba0265bd..dbf65aec97880 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -669,6 +669,22 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return *(indexingMaps.begin() + opOperand->getOperandNumber());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing map for a `result`.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getTiedIndexingMapForResult",
+      /*args=*/(ins "OpResult":$result),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(result.getOwner() == this->getOperation());
+        auto indexingMaps =
+          $_op.indexing_maps().template getAsValueRange<AffineMapAttr>();
+        return *(indexingMaps.begin() + getNumInputs() +
+                 result.getResultNumber());
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the result tied to `opOperand`.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d729ae5b850f9..24230a3730481 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -73,10 +73,21 @@ void populateFoldReshapeOpsByExpansionPatterns(
     const ControlElementwiseOpsFusionFn &controlFoldingReshapes =
         skipUnitDimReshape);
 
+/// Patterns to fold a tensor.expand_shape operation with its producer generic
+/// operation by collapsing the dimensions of the generic op.
+void populateFoldReshapeOpsByCollapsingPatterns(
+    RewritePatternSet &patterns,
+    const ControlElementwiseOpsFusionFn &controlFoldingReshapes =
+        [](const OpResult & /*producer*/, OpOperand & /*consumer*/) {
+          return true;
+        });
+
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
 /// producer (consumer) generic operation by linearizing the indexing map used
 /// to access the source (target) of the reshape operation in the generic
 /// operation.
+/// TODO(ravishankarm): These patterns are to be deprecated in favor of
+/// `populateFoldReshapeOpsByCollapsingPatterns`.
 void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
@@ -84,6 +95,8 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
 /// to access the source (target) of the reshape operation in the generic
 /// operation. The patterns are applied only when the tensor reshape involved is
 /// collapsing (introducing) unit-extent dimensions.
+/// TODO(ravishankarm): These patterns are to be deprecated in favor of
+/// `populateFoldReshapeOpsByCollapsingPatterns`.
 void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns);
 
@@ -153,6 +166,8 @@ void populateElementwiseOpsFusionPatterns(
 
 /// Patterns to push reshape op towards the end of the graph in order to expose
 /// more fusion opportunities.
+/// TODO(ravishankarm): These patterns are to be deprecated in favor of
+/// `populateFoldReshapeOpsByCollapsingPatterns`.
 void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
 
 /// Perform standalone tiling of a single LinalgOp by `tileSizes`.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index a30263990500d..570e844878d79 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -396,10 +396,11 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 // linearization of indexing maps.
 //===---------------------------------------------------------------------===//
 
-// TODO(ravishankarm): These patterns need to be deprecated. The indexing maps
+// TODO(ravishankarm): The indexing maps
 // these produce in the general case are detrimental to transformations.
-// They are useful now only in the limited case of unit-dimension folding.
-// Remove these in favor of more general folding by dimension contraction.
+// These patterns are on a deprecation path in favor of fusion by
+// collapsing, which covers the only legitimate use case of these
+// patterns: folding unit-extent dims.
 
 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
 /// provided, given the shape of the source tensor that corresponds to the
@@ -1144,11 +1145,470 @@ struct FoldReshapeWithGenericOpByExpansion
 };
 } // namespace
 
+//===---------------------------------------------------------------------===//
+// Methods and patterns to fuse reshape with linalg.generic operations by
+// contraction of dimensions.
+//===---------------------------------------------------------------------===//
+
+/// For an `indexingMap` that is a projected permutation, if the range is to be
+/// collapsed using the given `reassociation`, get the reassociation in the
+/// domain that would keep the map a projected permutation.
+static SmallVector<ReassociationIndices>
+getDomainReassociation(AffineMap indexingMap,
+                       ArrayRef<ReassociationIndices> rangeReassociation) {
+  assert(indexingMap.isProjectedPermutation() &&
+         "expected projected permutation map");
+  unsigned counter = 0;
+  SmallVector<ReassociationIndices> domainReassociation;
+  llvm::SmallDenseSet<unsigned, 4> processedDomainDims;
+  // Iterate over the reassociation indices.
+  for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) {
+    ReassociationIndices foldedDomainDims;
+    for (auto rangeDim : foldedRangeDims) {
+      (void)rangeDim;
+      AffineDimExpr dimExpr =
+          indexingMap.getResult(counter++).cast<AffineDimExpr>();
+      foldedDomainDims.push_back(dimExpr.getPosition());
+      processedDomainDims.insert(dimExpr.getPosition());
+    }
+    domainReassociation.emplace_back(std::move(foldedDomainDims));
+  }
+  // Fill in the missing domain dims.
+  for (auto dim : llvm::seq<unsigned>(0, indexingMap.getNumDims())) {
+    if (processedDomainDims.count(dim))
+      continue;
+    ReassociationIndices vec = {dim};
+    domainReassociation.emplace_back(std::move(vec));
+  }
+
+  // Sort the reassociation using the first dimension of each folded group to
+  // avoid creating unnecessary transposes.
+  llvm::sort(domainReassociation,
+             [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
+               return lhs[0] < rhs[0];
+             });
+  return domainReassociation;
+}
+
+/// For a given `dimSequence`, check if the sequence is preserved in the
+/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
+/// If the sequence does not occur in the map at all, return true as well.
+static bool isDimSequencePreserved(AffineMap indexingMap,
+                                   ReassociationIndicesRef dimSequence) {
+  assert(!dimSequence.empty() &&
+         "expected non-empty list for dimension sequence");
+  assert(indexingMap.isProjectedPermutation() &&
+         "expected indexing map to be projected permutation");
+
+  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
+  sequenceElements.insert(dimSequence.begin(), dimSequence.end());
+
+  unsigned dimSequenceStart = dimSequence[0];
+  for (auto expr : enumerate(indexingMap.getResults())) {
+    unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
+      // 1. Check if this is the start of the sequence.
+    if (dimInMapStart == dimSequenceStart) {
+      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
+        return false;
+      // 1a. Check if sequence is preserved.
+      for (auto dimInSequence : enumerate(dimSequence)) {
+        unsigned dimInMap =
+            indexingMap.getResult(expr.index() + dimInSequence.index())
+                .cast<AffineDimExpr>()
+                .getPosition();
+        if (dimInMap != dimInSequence.value())
+          return false;
+      }
+      // Found the sequence. Projected permutation
+      // enforces that all AffineDimExprs in the result are unique, so no
+      // further checks are needed.
+      return true;
+    }
+    // 2. If the position in the expr (which is an AffineDimExpr) is part of
+    // the sequence but is not its start, return false: the sequence is not
+    // preserved in the indexing map.
+    if (sequenceElements.count(dimInMapStart))
+      return false;
+  }
+  // 3. No element of the sequence was found. Return true.
+  return true;
+}
+
+// Check if a generic op can be fused along an operand by collapsing dimensions.
+static bool isFusableWithReshapeByDimCollapse(
+    GenericOp genericOp, OpOperand *fusableOperand,
+    ArrayRef<ReassociationIndices> reassociation) {
+  // Some basic checks for this fusion to be valid.
+  if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
+    return false;
+
+  if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
+        return map.isProjectedPermutation();
+      })) {
+    return false;
+  }
+
+  // Get the reassociation for the iteration space.
+  SmallVector<ReassociationIndices> iterationReassociation =
+      getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand),
+                             reassociation);
+  if (iterationReassociation.empty()) {
+    // If the domain reassociation indices are empty, then this is a scalar op.
+    // Nothing to do.
+    return false;
+  }
+
+  auto iteratorTypes = genericOp.iterator_types().getValue();
+  ArrayRef<Attribute> iteratorTypesRef(iteratorTypes);
+  for (ReassociationIndicesRef foldedIterDims : iterationReassociation) {
+    // Check that for all indexing maps, the folded dimensions sequence is
+    // preserved.
+    if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
+          return isDimSequencePreserved(indexingMap, foldedIterDims);
+        }))
+      return false;
+    unsigned startDim = foldedIterDims[0];
+    ArrayRef<Attribute> foldedIteratorTypes =
+        iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size());
+    // Check that the folded iterator types are either all parallel or all
+    // reduction.
+    if (!llvm::all_of(
+            foldedIteratorTypes,
+            [](Attribute attr) { return isParallelIterator(attr); }) &&
+        !llvm::all_of(foldedIteratorTypes,
+                      [](Attribute attr) { return isReductionIterator(attr); }))
+      return false;
+  }
+  return true;
+}
+
+/// Helper class to carry state while collapsing the `linalg.generic` op.
+namespace {
+class CollapsingInfo {
+public:
+  CollapsingInfo(SmallVector<ReassociationIndices> &&reassociation) {
+    iterationReassociation = std::move(reassociation);
+    for (auto foldedIterDims : enumerate(iterationReassociation)) {
+      foldedDimStartToSequenceMap[foldedIterDims.value()[0]] =
+          foldedIterDims.index();
+    }
+  }
+
+  // Returns the iteration space reassociation.
+  ArrayRef<ReassociationIndices> getReassociationIndices() {
+    return iterationReassociation;
+  }
+
+  // Returns true if the given dimension is the start of a sequence of folded
+  // dimensions.
+  bool isDimStartOfFoldedDims(unsigned dim) {
+    return foldedDimStartToSequenceMap.count(dim);
+  }
+
+  // Return the folded dimensions starting at `dim`.
+  ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) {
+    assert(foldedDimStartToSequenceMap.count(dim) &&
+           "invalid start dim of folded dim "
+           "sequence");
+    return iterationReassociation[foldedDimStartToSequenceMap[dim]];
+  }
+
+  // For a dim in the original op, return the dim in the collapsed op that it
+  // is mapped to. Expects `dim` to be the start of a folded dimension sequence.
+  unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) {
+    assert(foldedDimStartToSequenceMap.count(dim) &&
+           "invalid start dim of folded dim sequence");
+    return foldedDimStartToSequenceMap[dim];
+  }
+
+private:
+  /// Reassociation describing the folded iteration space dimensions.
+  SmallVector<ReassociationIndices> iterationReassociation;
+
+  /// Map from the starting dimensions of the folded dimension sequences to
+  /// their index in `iterationReassociation`.
+  llvm::DenseMap<unsigned, unsigned> foldedDimStartToSequenceMap;
+};
+} // namespace
+
+/// Get the iterator types for the collapsed operation given the original
+/// iterator types and collapsed dimensions.
+static SmallVector<StringRef>
+getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
+                            CollapsingInfo &collapsingInfo) {
+  SmallVector<StringRef> collapsedIteratorTypes;
+  for (ReassociationIndicesRef foldedIterDims :
+       collapsingInfo.getReassociationIndices()) {
+    assert(!foldedIterDims.empty() &&
+           "reassociation indices expected to have non-empty sets");
+    // Just pick the iterator type of the first folded dim. Precondition checks
+    // are expected to have verified that the iterator types of all folded
+    // dimensions are the same.
+    collapsedIteratorTypes.push_back(
+        iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
+  }
+  return collapsedIteratorTypes;
+}
+
+/// Compute the indexing map in the collapsed op that corresponds to the given
+/// `indexingMap` of the original operation.
+static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap,
+                                           CollapsingInfo &collapsingInfo) {
+  MLIRContext *context = indexingMap.getContext();
+  assert(indexingMap.isProjectedPermutation() &&
+         "expected indexing map to be projected permutation");
+  SmallVector<AffineExpr> resultExprs;
+  for (auto expr : indexingMap.getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
+      resultExprs.push_back(getAffineDimExpr(
+          collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim),
+          context));
+    }
+  }
+  return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0,
+                        resultExprs, context);
+}
+
+/// Return the `reassociation` indices to use to collapse the operand when the
+/// iteration space of a generic op is collapsed.
+static SmallVector<ReassociationIndices>
+getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
+  unsigned counter = 0;
+  SmallVector<ReassociationIndices> operandReassociation;
+  for (auto expr : indexingMap.getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
+      unsigned numFoldedDims =
+          collapsingInfo.getFoldedDimsStartingAt(dim).size();
+      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
+      operandReassociation.emplace_back(range.begin(), range.end());
+      counter += numFoldedDims;
+    }
+  }
+  return operandReassociation;
+}
+
+/// Get the new value to use for a given `OpOperand` in the collapsed operation.
+static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
+                                   OpOperand *opOperand,
+                                   CollapsingInfo &collapsingInfo,
+                                   OpBuilder &builder) {
+  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+  SmallVector<ReassociationIndices> operandReassociation =
+      getOperandReassociation(indexingMap, collapsingInfo);
+
+  // If the number of entries in the operand's reassociation is the same as the
+  // number of results of the indexing map, nothing to do for this operand.
+  Value operand = opOperand->get();
+  if (operandReassociation.size() == indexingMap.getNumResults())
+    return operand;
+
+  // Insert a reshape to collapse the dimensions.
+  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
+      loc, operand, operandReassociation);
+  return reshapeOp.getResult();
+}
+
+/// Modify the `linalg.index` operations in the original generic op to their
+/// values in the collapsed operation.
+void generateCollapsedIndexingRegion(Location loc, Block *block,
+                                     CollapsingInfo &collapsingInfo,
+                                     ValueRange loopRange,
+                                     PatternRewriter &rewriter) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointToStart(block);
+
+  // Collect all the original index ops.
+  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
+
+  // For each folded dimension list, resolve the original induction variable
+  // values in terms of the folded dimension induction variable.
+  //   i_{folded} = (i0 * d1 + i1) * d2 + i2
+  // can be inverted to
+  //   i2 = i_{folded} % d2
+  //   i1 = (i_{folded} / d2) % d1
+  //   i0 = i_{folded} / (d1 * d2)
+  llvm::DenseMap<unsigned, Value> indexReplacementVals;
+  for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) {
+    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
+    Value newIndexVal =
+        rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
+    for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
+      indexReplacementVals[dim] =
+          rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
+      newIndexVal =
+          rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
+    }
+    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
+  }
+
+  for (auto indexOp : indexOps) {
+    auto dim = indexOp.dim();
+    rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
+  }
+}
+
+/// Implementation of fusion with reshape operation by collapsing dimensions.
+static Optional<SmallVector<Value>>
+fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp,
+                            OpOperand *fusableOpOperand,
+                            PatternRewriter &rewriter) {
+  SmallVector<ReassociationIndices> reassociation =
+      isa<tensor::CollapseShapeOp>(reshapeOp)
+          ? cast<tensor::CollapseShapeOp>(reshapeOp).getReassociationIndices()
+          : cast<tensor::ExpandShapeOp>(reshapeOp).getReassociationIndices();
+  assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand,
+                                           reassociation) &&
+         "preconditions for fusing with reshape by collapse failed");
+
+  CollapsingInfo collapsingInfo(getDomainReassociation(
+      genericOp.getTiedIndexingMap(fusableOpOperand), reassociation));
+  // Trivial case: nothing to collapse. Return llvm::None.
+  if (collapsingInfo.getReassociationIndices().size() ==
+      genericOp.getNumLoops())
+    return llvm::None;
+
+  // Get the iterator types for the collapsed operation.
+  SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
+      genericOp.iterator_types().getValue(), collapsingInfo);
+
+  // Get the indexing maps.
+  auto indexingMaps = llvm::to_vector(
+      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) {
+        return getCollapsedOpIndexingMap(map, collapsingInfo);
+      }));
+
+  Location loc =
+      rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()});
+
+  // Get the input operands.
+  auto inputOperands = llvm::to_vector(
+      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
+        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
+                                     rewriter);
+      }));
+
+  // Get the output operands and result types.
+  SmallVector<Type> resultTypes;
+  SmallVector<Value> outputOperands;
+  resultTypes.reserve(genericOp.getNumOutputs());
+  outputOperands.reserve(genericOp.getNumOutputs());
+  for (OpOperand *output : genericOp.getOutputOperands()) {
+    Value newOutput =
+        getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
+    outputOperands.push_back(newOutput);
+    resultTypes.push_back(newOutput.getType());
+  }
+
+  // Create the generic op.
+  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
+      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
+  Block *origOpBlock = &genericOp->getRegion(0).front();
+  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
+  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
+                       collapsedOpBlock->getArguments());
+
+  if (collapsedGenericOp.hasIndexSemantics()) {
+    // Collect the loop range of the generic op.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(collapsedGenericOp);
+    SmallVector<Range> loopRanges =
+        cast<LinalgOp>(genericOp.getOperation())
+            .createLoopRanges(rewriter, genericOp.getLoc());
+    assert(llvm::all_of(loopRanges,
+                        [](Range range) {
+                          return matchPattern(range.offset, m_Zero()) &&
+                                 matchPattern(range.stride, m_One());
+                        }) &&
+           "expected all loop ranges to have zero start and unit stride");
+    SmallVector<Value> loopBound = llvm::to_vector(
+        llvm::map_range(loopRanges, [](Range range) { return range.size; }));
+    generateCollapsedIndexingRegion(loc,
+                                    &collapsedGenericOp->getRegion(0).front(),
+                                    collapsingInfo, loopBound, rewriter);
+  }
+
+  // Insert an expanding reshape for the result to get back the original result
+  // type.
+  SmallVector<Value> results;
+  for (auto originalResult : llvm::enumerate(genericOp->getResults())) {
+    Value collapsedOpResult =
+        collapsedGenericOp->getResult(originalResult.index());
+    auto originalResultType =
+        originalResult.value().getType().cast<ShapedType>();
+    auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
+    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
+      AffineMap indexingMap =
+          genericOp.getTiedIndexingMapForResult(originalResult.value());
+      SmallVector<ReassociationIndices> reassociation =
+          getOperandReassociation(indexingMap, collapsingInfo);
+      Value result = rewriter.create<tensor::ExpandShapeOp>(
+          loc, originalResultType, collapsedOpResult, reassociation);
+      results.push_back(result);
+    } else {
+      results.push_back(collapsedOpResult);
+    }
+  }
+  return results;
+}
+
+namespace {
+
+/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
+/// collapsing the dimensions of the iteration space.
+class FoldWithProducerReshapeOpByCollapsing
+    : public OpRewritePattern<GenericOp> {
+public:
+  FoldWithProducerReshapeOpByCollapsing(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
+      tensor::ExpandShapeOp reshapeOp =
+          opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
+      if (!reshapeOp)
+        continue;
+
+      if (!isFusableWithReshapeByDimCollapse(
+              genericOp, opOperand, reshapeOp.getReassociationIndices()) ||
+          !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
+        continue;
+      }
+
+      Optional<SmallVector<Value>> replacements = fuseWithReshapeByCollapsing(
+          genericOp, reshapeOp, opOperand, rewriter);
+      if (!replacements) {
+        return rewriter.notifyMatchFailure(
+            genericOp, "failed to do the fusion by collapsing transformation");
+      }
+
+      rewriter.replaceOp(genericOp, replacements.getValue());
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
+};
+} // namespace
+
 //===---------------------------------------------------------------------===//
 // Methods and patterns to convert tensor.expand_shape -> linalg.generic
 // into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
 //===---------------------------------------------------------------------===//
 
+// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
+// collapsing, which provides more general functionality. This pattern is very
+// specific to a particular use case; fusion by collapsing can provide the same
+// control to clients through its control function.
+
 static SmallVector<ReassociationIndices>
 getReassociationIndices(ArrayRef<AffineMap> maps) {
   SmallVector<ReassociationIndices> reassociation;
@@ -1785,6 +2245,13 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                      controlFoldingReshapes);
 }
 
+void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
+    RewritePatternSet &patterns,
+    const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
+  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
+                                                      controlFoldingReshapes);
+}
+
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
   auto *context = patterns.getContext();

diff  --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
new file mode 100644
index 0000000000000..c45c5b296a781
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -0,0 +1,401 @@
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing-control -split-input-file | FileCheck %s --check-prefix=CONTROL
+
+// Static problem sizes. Checks all aspects of fusion by collapsing. The rest
+// of the tests only check a subset of the conditions.
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>,
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
+      : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %init = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2, #map3],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+    ins(%expand, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+    outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+        %t0 = arith.addi %b0, %b1 : i32
+        %t1 = arith.addi %t0, %b2 : i32
+        linalg.yield %t1 : i32
+    } -> tensor<2x3x4x5x6x7x8x9xi32>
+  return %generic : tensor<2x3x4x5x6x7x8x9xi32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+//      CHECK: func @fuse_by_collapsing(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
+// CHECK-SAME:   %[[ARG1:.+]]: tensor<2x3x4xi32>
+// CHECK-SAME:   %[[ARG2:.+]]: tensor<5x6x7x8xi32>
+//  CHECK-DAG:   %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9]
+//  CHECK-DAG:   %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
+//  CHECK-DAG:   %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+//  CHECK-DAG:   %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+//      CHECK:   %[[COLLAPSED_OP:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+// CHECK-SAME:       outs(%[[INIT_RESHAPE]] :
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+//      CHECK:   return %[[RESULT_RESHAPE]]
+
+//      CONTROL: func @fuse_by_collapsing(
+// CONTROL-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
+// CONTROL-SAME:   %[[ARG1:.+]]: tensor<2x3x4xi32>
+// CONTROL-SAME:   %[[ARG2:.+]]: tensor<5x6x7x8xi32>
+//      CONTROL:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
+//      CONTROL:   %[[GENERIC:.+]] = linalg.generic
+// CONTROL-SAME:       ins(%[[EXPAND]],
+//      CONTROL:   return %[[GENERIC]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func @fuse_by_collapsing_indexing_op(%arg0 : tensor<2x12x5x336x9xi32>,
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
+      : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %init = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2, #map3],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+    ins(%expand, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+    outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+        %iv0 = linalg.index 0: index
+        %iv1 = linalg.index 1: index
+        %t0 = arith.addi %iv0, %iv1 : index
+        %iv2 = linalg.index 2 : index
+        %t1 = arith.addi %t0, %iv2 : index
+        %iv3 = linalg.index 3 : index
+        %t2 = arith.addi %t1, %iv3 : index
+        %iv4 = linalg.index 4 : index
+        %t3 = arith.addi %t2, %iv4 : index
+        %iv5 = linalg.index 5 : index
+        %t4 = arith.addi %t3, %iv5 : index
+        %iv6 = linalg.index 6 : index
+        %t5 = arith.addi %t4, %iv6 : index
+        %iv7 = linalg.index 7 : index
+        %t6 = arith.addi %t5, %iv7 : index
+        %yield = arith.index_cast %t6 : index to i32
+        linalg.yield %yield : i32
+    } -> tensor<2x3x4x5x6x7x8x9xi32>
+  return %generic : tensor<2x3x4x5x6x7x8x9xi32>
+}
+// CHECK-LABEL: func @fuse_by_collapsing_indexing_op(
+//   CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+//   CHECK-DAG:   %[[C7:.+]] = arith.constant 7 : index
+//       CHECK:     %[[IV0:.+]] = linalg.index 0
+//       CHECK:     %[[IV1:.+]] = linalg.index 1
+//       CHECK:     %[[REM_IV1:.+]] = arith.remui %[[IV1]], %[[C4]]
+//       CHECK:     %[[DIV_IV1:.+]] = arith.divui %[[IV1]], %[[C4]]
+//       CHECK:     %[[IV2:.+]] = linalg.index 2
+//       CHECK:     %[[IV3:.+]] = linalg.index 3
+//       CHECK:     %[[REM1_IV3:.+]] = arith.remui %[[IV3]], %[[C8]]
+//       CHECK:     %[[DIV1_IV3:.+]] = arith.divui %[[IV3]], %[[C8]]
+//       CHECK:     %[[REM2_IV3:.+]] = arith.remui %[[DIV1_IV3]], %[[C7]]
+//       CHECK:     %[[DIV2_IV3:.+]] = arith.divui %[[DIV1_IV3]], %[[C7]]
+//       CHECK:     %[[IV4:.+]] = linalg.index 4
+//       CHECK:     %[[T0:.+]] = arith.addi %[[IV0]], %[[DIV_IV1]]
+//       CHECK:     %[[T1:.+]] = arith.addi %[[T0]], %[[REM_IV1]]
+//       CHECK:     %[[T2:.+]] = arith.addi %[[T1]], %[[IV2]]
+//       CHECK:     %[[T3:.+]] = arith.addi %[[T2]], %[[DIV2_IV3]]
+//       CHECK:     %[[T4:.+]] = arith.addi %[[T3]], %[[REM2_IV3]]
+//       CHECK:     %[[T5:.+]] = arith.addi %[[T4]], %[[REM1_IV3]]
+//       CHECK:     %[[T6:.+]] = arith.addi %[[T5]], %[[IV4]]
+//       CHECK:     %[[YIELD:.+]] = arith.index_cast %[[T6]]
+//       CHECK:     linalg.yield %[[YIELD]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d5, d6, d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func @fuse_by_collapsing_change_reshape_order(%arg0 : tensor<9x56x2x60x6xi32>,
+    %arg1 : tensor<7x8x2xi32>, %arg2 : tensor<6x3x4x5xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
+      : tensor<9x56x2x60x6xi32> into tensor<9x7x8x2x3x4x5x6xi32>
+  %init = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2, #map3],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+    ins(%expand, %arg1, %arg2 : tensor<9x7x8x2x3x4x5x6xi32>, tensor<7x8x2xi32>, tensor<6x3x4x5xi32>)
+    outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+        %t0 = arith.addi %b0, %b1 : i32
+        %t1 = arith.addi %t0, %b2 : i32
+        linalg.yield %t1 : i32
+    } -> tensor<2x3x4x5x6x7x8x9xi32>
+  return %generic : tensor<2x3x4x5x6x7x8x9xi32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d3, d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//      CHECK: func @fuse_by_collapsing_change_reshape_order(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<9x56x2x60x6xi32>
+// CHECK-SAME:   %[[ARG1:.+]]: tensor<7x8x2xi32>
+// CHECK-SAME:   %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+//  CHECK-DAG:   %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9]
+//  CHECK-DAG:   %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1], [2]{{\]}}
+//  CHECK-DAG:   %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+//  CHECK-DAG:   %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}}
+//      CHECK:   %[[COLLAPSED_OP:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+// CHECK-SAME:       outs(%[[INIT_RESHAPE]] :
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}}
+//      CHECK:   return %[[RESULT_RESHAPE]]
+
+// -----
+
+// Dynamic case. Only checks things not covered by `fuse_by_collapsing` test above.
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d5, d6, d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
+    %arg1 : tensor<?x?x?xi32>, %arg2 : tensor<?x?x?x?xi32>) -> tensor<?x3x?x5x?x7x?x?xi32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
+      : tensor<?x?x?x?x?xi32> into tensor<?x7x?x?x3x?x5x?xi32>
+  %d0 = tensor.dim %arg1, %c2 : tensor<?x?x?xi32>
+  %d2 = tensor.dim %arg2, %c2 : tensor<?x?x?x?xi32>
+  %d4 = tensor.dim %arg2, %c0 : tensor<?x?x?x?xi32>
+  %d6 = tensor.dim %arg1, %c1 : tensor<?x?x?xi32>
+  %d7 = tensor.dim %arg0, %c0 : tensor<?x?x?x?x?xi32>
+  %init = linalg.init_tensor [%d0, 3, %d2, 5, %d4, 7, %d6, %d7] : tensor<?x3x?x5x?x7x?x?xi32>
+  %generic = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2, #map3],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+    ins(%expand, %arg1, %arg2 : tensor<?x7x?x?x3x?x5x?xi32>, tensor<?x?x?xi32>, tensor<?x?x?x?xi32>)
+    outs(%init : tensor<?x3x?x5x?x7x?x?xi32>) {
+      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+        %iv0 = linalg.index 0: index
+        %iv1 = linalg.index 1: index
+        %t0 = arith.addi %iv0, %iv1 : index
+        %iv2 = linalg.index 2 : index
+        %t1 = arith.addi %t0, %iv2 : index
+        %iv3 = linalg.index 3 : index
+        %t2 = arith.addi %t1, %iv3 : index
+        %iv4 = linalg.index 4 : index
+        %t3 = arith.addi %t2, %iv4 : index
+        %iv5 = linalg.index 5 : index
+        %t4 = arith.addi %t3, %iv5 : index
+        %iv6 = linalg.index 6 : index
+        %t5 = arith.addi %t4, %iv6 : index
+        %iv7 = linalg.index 7 : index
+        %t6 = arith.addi %t5, %iv7 : index
+        %yield = arith.index_cast %t6 : index to i32
+        linalg.yield %yield : i32
+    } -> tensor<?x3x?x5x?x7x?x?xi32>
+  return %generic : tensor<?x3x?x5x?x7x?x?xi32>
+}
+//      CHECK: func @fuse_by_collapsing_dynamic(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?x?x?xi32>
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[EXPAND]], %[[C2]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[EXPAND]], %[[C5]]
+//      CHECK:   linalg.generic
+//      CHECK:     %[[IV0:.+]] = linalg.index 1
+//      CHECK:     %[[REM1_IV0:.+]] = arith.remui %[[IV0]], %[[C5]]
+//      CHECK:     %[[DIV1_IV0:.+]] = arith.divui %[[IV0]], %[[C5]]
+//      CHECK:     %[[REM2_IV0:.+]] = arith.remui %[[DIV1_IV0]], %[[D1]]
+//      CHECK:     %[[DIV2_IV0:.+]] = arith.divui %[[DIV1_IV0]], %[[D1]]
+//      CHECK:     %[[IV1:.+]] = linalg.index 3
+//      CHECK:     %[[REM1_IV1:.+]] = arith.remui %[[IV1]], %[[D0]]
+//      CHECK:     %[[DIV1_IV1:.+]] = arith.divui %[[IV1]], %[[D0]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
+func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>) -> tensor<2x5xf32> {
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x?x5xf32> into tensor<2x6x?x5xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map1],
+      iterator_types = ["parallel", "reduction", "reduction", "parallel"]}
+      ins(%0 : tensor<2x6x?x5xf32>) outs(%arg1 : tensor<2x5xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32):
+          %2 = arith.addf %b0, %b1 : f32
+          linalg.yield %2 : f32
+      } -> tensor<2x5xf32>
+  return %1 : tensor<2x5xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//      CHECK: func @fuse_reductions(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x?x5xf32>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<2x5xf32>) -> tensor<2x5xf32>
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:       iterator_types = ["parallel", "reduction", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]] : tensor<2x?x5xf32>)
+// CHECK-SAME:       outs(%[[ARG1]] : tensor<2x5xf32>)
+
+// -----
+
+// Test no fusion because the folded dimensions are not all preserved.
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+func @no_fuse_unpreserved_folding(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2x3xf32>) -> tensor<2x3x4x5xf32> {
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32>
+  %init = linalg.init_tensor [2, 3, 4, 5] : tensor<2x3x4x5xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map0],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<2x3x4x5xf32>, tensor<2x3xf32>) outs(%init : tensor<2x3x4x5xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+          %2 = arith.addf %b0, %b1 : f32
+          linalg.yield %2 : f32
+      } -> tensor<2x3x4x5xf32>
+  return %1 : tensor<2x3x4x5xf32>
+}
+//      CHECK: func @no_fuse_unpreserved_folding
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x12x5xf32>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<2x3xf32>
+//      CHECK:   %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] :
+//      CHECK:   return %[[GENERIC]]
+
+// -----
+
+// Test no fusion because the folded dimensions are not all preserved.
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
+func @no_fuse_unpreserved_folding_transpose(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2xf32>) -> tensor<2x4x3x5xf32> {
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32>
+  %init = linalg.init_tensor [2, 4, 3, 5] : tensor<2x4x3x5xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<2x3x4x5xf32>, tensor<2xf32>) outs(%init : tensor<2x4x3x5xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+          %2 = arith.addf %b0, %b1 : f32
+          linalg.yield %2 : f32
+      } -> tensor<2x4x3x5xf32>
+  return %1 : tensor<2x4x3x5xf32>
+}
+//      CHECK: func @no_fuse_unpreserved_folding_transpose
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x12x5xf32>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<2xf32>
+//      CHECK:   %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] :
+//      CHECK:   return %[[GENERIC]]
+
+// -----
+
+// Test no fusion because the iterator types of folded dims are not preserved.
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
+func @no_fuse_mismatched_iterator_types(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2x3xf32>) -> tensor<2x5xf32> {
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32>
+  %init = linalg.init_tensor [2, 5] : tensor<2x5xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<2x3x4x5xf32>, tensor<2x3xf32>) outs(%init : tensor<2x5xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+          %2 = arith.addf %b0, %b1 : f32
+          linalg.yield %2 : f32
+      } -> tensor<2x5xf32>
+  return %1 : tensor<2x5xf32>
+}
+//      CHECK: func @no_fuse_mismatched_iterator_types
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x12x5xf32>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<2x3xf32>
+//      CHECK:   %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] :
+//      CHECK:   return %[[GENERIC]]
+
+// -----
+
+// Test control of fusion using the control function. In the test pass, the
+// control function skips folding the reshape of the first operand.
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tensor<2x3x4x5xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+  %1 = tensor.expand_shape %arg1 [[0, 1]] : tensor<20xf32> into tensor<4x5xf32>
+    %init = linalg.init_tensor [2, 3, 4, 5] : tensor<2x3x4x5xf32>
+  %2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%0, %1 : tensor<2x3xf32>, tensor<4x5xf32>) outs(%init : tensor<2x3x4x5xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+          %3 = arith.addf %b0, %b1 : f32
+          linalg.yield %3 : f32
+      } -> tensor<2x3x4x5xf32>
+  return %2 : tensor<2x3x4x5xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//      CHECK: func @control_fusion(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<6xf32>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<20xf32>
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME:       outs(%{{.+}}: tensor<6x20xf32>)
+//      CHECK:   %[[RESHAPE1:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]{{\]}}
+//      CHECK:   %[[RESHAPE2:.+]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1], [2], [3]{{\]}}
+//      CHECK:   return %[[RESHAPE2]]
+
+//  CONTROL-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//  CONTROL-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+//  CONTROL-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//      CONTROL: func @control_fusion(
+// CONTROL-SAME:     %[[ARG0:.+]]: tensor<6xf32>
+// CONTROL-SAME:     %[[ARG1:.+]]: tensor<20xf32>
+//      CONTROL:     %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
+//      CONTROL:     %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 5]
+//      CONTROL:     %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]{{\]}}
+//      CONTROL:     %[[GENERIC:.+]] = linalg.generic
+// CONTROL-SAME:         ins(%[[EXPAND]], %[[ARG1]] :
+// CONTROL-SAME:         outs(%[[INIT_RESHAPE]] :
+//      CONTROL:     %[[RESULT:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}}
+
+// -----
+
+// Corner case that isn't handled currently.
+#map = affine_map<(d0) -> (d0)>
+func @zero_D_test(%arg0: tensor<f32>) -> tensor<1xf32> {
+  %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
+  %init = linalg.init_tensor [1] : tensor<1xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map, #map],
+      iterator_types = ["parallel"]}
+      ins(%0: tensor<1xf32>) outs(%init : tensor<1xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32):
+          linalg.yield %b0: f32
+      } -> tensor<1xf32>
+  return %1 : tensor<1xf32>
+}
+//      CHECK: func @zero_D_test
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<f32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[EXPAND]] :
+//      CHECK:   return %[[GENERIC]]

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 3efa97f941406..16c81b5612d08 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -78,6 +78,19 @@ struct TestLinalgElementwiseFusion
                      "to generic -> expand_shape pattern"),
       llvm::cl::init(false)};
 
+  Option<bool> fuseWithReshapeByCollapsing{
+      *this, "fuse-with-reshape-by-collapsing",
+      llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
+                     "collapse the iteration space of the consumer"),
+      llvm::cl::init(false)};
+
+  Option<bool> fuseWithReshapeByCollapsingWithControlFn{
+      *this, "fuse-with-reshape-by-collapsing-control",
+      llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
+                     "fusion patterns that "
+                     "collapse the iteration space of the consumer"),
+      llvm::cl::init(false)};
+
   void runOnOperation() override {
     MLIRContext *context = &this->getContext();
     FuncOp funcOp = this->getOperation();
@@ -129,6 +142,26 @@ struct TestLinalgElementwiseFusion
       linalg::populatePushReshapeOpsPatterns(patterns);
       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
     }
+
+    if (fuseWithReshapeByCollapsing) {
+      RewritePatternSet patterns(context);
+      linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns);
+      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+    }
+
+    if (fuseWithReshapeByCollapsingWithControlFn) {
+      RewritePatternSet patterns(context);
+      linalg::ControlElementwiseOpsFusionFn controlFn =
+          [](const OpResult &producer, OpOperand &consumer) -> bool {
+        if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) {
+          // Skip fusing the first operand: operand number 0 converts to false.
+          return consumer.getOperandNumber();
+        }
+        return true;
+      };
+      linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
+      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+    }
   }
 };
 


        

