[Mlir-commits] [mlir] 0c090dc - [mlir][Linalg] Deprecate legacy reshape + generic op folding patterns.

Mahesh Ravishankar llvmlistbot at llvm.org
Thu Apr 21 15:25:36 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-04-21T22:25:23Z
New Revision: 0c090dcc8a97a07bb3b3d2f64dbd1abf3990c1c6

URL: https://github.com/llvm/llvm-project/commit/0c090dcc8a97a07bb3b3d2f64dbd1abf3990c1c6
DIFF: https://github.com/llvm/llvm-project/commit/0c090dcc8a97a07bb3b3d2f64dbd1abf3990c1c6.diff

LOG: [mlir][Linalg] Deprecate legacy reshape + generic op folding patterns.

These patterns have been superseded by the fusion by collapsing patterns.

Differential Revision: https://reviews.llvm.org/D124145

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
    mlir/test/Dialect/Linalg/reshape_fusion.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

Removed: 
    mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
    mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 9f717d07d276e..06f0e217986d7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -90,31 +90,11 @@ def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
 def LinalgElementwiseOpFusion : Pass<"linalg-fuse-elementwise-ops"> {
   let summary = "Fuse elementwise operations on tensors";
   let constructor = "mlir::createLinalgElementwiseOpFusionPass()";
-  let options = [
-    Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
-           "bool", /*default=*/"false",
-           "Allow fusing linalg.tensor_reshape ops that performs unit "
-           "dimension collapsing">
-  ];
   let dependentDialects = [
     "AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
   ];
 }
 
-def LinalgFoldReshapeOpsByLinearization :
-  Pass<"linalg-fold-reshape-ops-by-linearization"> {
-  let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
-                "linearization";
-  let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
-  let options = [
-    Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
-           "bool", /*default=*/"false",
-           "Allow fusing linalg.tensor_reshape ops that performs unit "
-           "dimension collapsing">
-  ];
-  let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
-}
-
 def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
   let summary = "Convert from one named linalg op to another.";
   let constructor = "mlir::createLinalgNamedOpConversionPass()";

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b4fefc21132e7..188f2b436a3d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -37,10 +37,6 @@ struct LinalgElementwiseFusionOptions;
 struct LinalgFusionOptions;
 struct LinalgTilingOptions;
 
-/// Default function to control reshape folding. Skips folding unit dimension
-/// reshapes.
-bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
-
 //===----------------------------------------------------------------------===//
 // Transformations exposed as function calls.
 //===----------------------------------------------------------------------===//
@@ -91,24 +87,6 @@ void populateFoldReshapeOpsByCollapsingPatterns(
 void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
                                           const ControlFusionFn &controlFn);
 
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic operation by linearizing the indexing map used
-/// to access the source (target) of the reshape operation in the generic
-/// operation.
-/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
-/// the `populateFoldReshapeByCollapsingPatterns`.
-void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
-
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic operation by linearizing the indexing map used
-/// to access the source (target) of the reshape operation in the generic
-/// operation. The patterns are applied only when the tensor reshape involved is
-/// collapsing (introducing) unit-extent dimensions.
-/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
-/// the `populateFoldReshapeByCollapsingPatterns`.
-void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
-    RewritePatternSet &patterns);
-
 /// Pattern to fuse a `tensor.pad` operation with the producer of its source,
 /// if the producer is a `linalg` operation with all parallel iterator types.
 void populateFuseTensorPadWithProducerLinalgOpPatterns(
@@ -128,12 +106,6 @@ void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 /// Patterns that are used to bubble up extract slice op above linalg op.
 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 
-/// Patterns to push reshape op towards the end of the graph in order to expose
-/// more fusion opportunities.
-/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
-/// the `populateFoldReshapeByCollapsingPatterns`.
-void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
-
 /// Perform standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3aabac2ba456a..cc0ec0866f842 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -392,263 +392,6 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 };
 } // namespace
 
-//===---------------------------------------------------------------------===//
-// Methods and patterns that fuse reshape ops with elementwise operations by
-// linearization of indexing maps.
-//===---------------------------------------------------------------------===//
-
-// TODO(ravishankarm): The indexing maps
-// these produce in the general case are detrimental to transformations.
-// These patterns are on deprecation path in favor of using fusion by
-// collapsing, which covers the only legitimate use case of this pattern of
-// folding unit-extent dims.
-
-/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
-/// provided, given the shape of the source tensor that corresponds to the
-/// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
-/// are "row-major" ordered logically.
-///
-/// For example:
-///
-/// %0 = op ... : tensor<?x?x4x5xf32>
-/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
-///
-/// and reshape:
-/// %1 = tensor.collapse_shape %0 [[0], [0, 1, 2]] :
-///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
-///
-/// would be rewritten into:
-/// %0 = op ... : tensor<?x?x4x5xf32>
-/// with output index_map
-///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
-template <typename TensorReshapeOp>
-static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
-                                        TensorReshapeOp reshapeOp) {
-  constexpr bool isExpanding =
-      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
-  ArrayRef<int64_t> sourceShape =
-      (isExpanding ? reshapeOp.getResultType().getShape()
-                   : reshapeOp.getSrcType().getShape());
-  SmallVector<AffineExpr> resultExprs;
-  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
-  MLIRContext *context = sourceMap.getContext();
-
-  // Compute the result exprs based on the reassociation maps.
-  for (auto &indices : reshapeOp.getReassociationIndices()) {
-    // Assume that they are in-order and contiguous (already checked in
-    // verifier).
-    assert(!indices.empty());
-    SmallVector<int64_t> sizes;
-    SmallVector<AffineExpr> dimExprs;
-    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
-                             sourceExprs.slice(indices[0], indices.size()))) {
-      if (std::get<0>(en) == 1)
-        continue;
-      sizes.push_back(std::get<0>(en));
-      dimExprs.push_back(std::get<1>(en));
-    }
-    AffineExpr linearizedExpr =
-        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
-    resultExprs.push_back(linearizedExpr);
-  }
-  // The new affine map cannot drop unused dimension but some new symbols may
-  // have been added. Create a map with at least as many dimensions/symbols as
-  // the original affine map.
-  int64_t maxDim = -1;
-  int64_t maxSym = -1;
-  getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
-  unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
-  unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
-  return AffineMap::get(numDims, numSyms, resultExprs, context);
-}
-
-// tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
-// producer). Fusing when operand has higher rank will require use of mods and
-// divs in the indexing maps of the fused op which would make it non-invertible.
-static bool isTensorReshapeOpFoldableByLinearization(
-    tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
-  if (!asProducer)
-    return false;
-  return useIndexMap.isPermutation();
-}
-
-// tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
-// consumer).
-static bool
-isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
-                                         AffineMap useIndexMap,
-                                         bool asProducer) {
-  if (asProducer)
-    return false;
-  return useIndexMap.isPermutation();
-}
-
-/// Check if the reshape operation is only expansion into/collapsing of
-/// unit-dimension.
-template <typename TensorReshapeOp>
-static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
-  constexpr bool isExpanding =
-      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
-  ArrayRef<int64_t> expandedShape =
-      (isExpanding ? reshapeOp.getResultType().getShape()
-                   : reshapeOp.getSrcType().getShape());
-  for (auto &indices : reshapeOp.getReassociationIndices()) {
-    unsigned numUnitDims = 0;
-    for (int64_t position : indices)
-      if (expandedShape[position] == 1)
-        numUnitDims++;
-    if (numUnitDims != indices.size() - 1)
-      return false;
-  }
-  return true;
-}
-
-namespace {
-/// Pattern to fold tensor_expand_shape op with its consumer by using the source
-/// of the reshape op as the operand in the consumer (instead of the result of
-/// the tensor_collapse_shape). The corresponding index map in the consumer
-/// needs to be modified to linearize the folded dimension.
-///
-/// For example,
-///
-/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
-///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
-/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
-///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
-///        -> tensor<?x?x4x?xf32>
-///
-/// can be folded into
-///
-/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
-///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
-///        -> tensor<?x?x4x?xf32>
-template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
-struct FoldProducerReshapeOpByLinearization
-    : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
-    for (const auto &en : llvm::enumerate(inputOperands)) {
-      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
-      if (!reshapeOp)
-        continue;
-
-      if (!isTensorReshapeOpFoldableByLinearization(
-              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
-              /*asProducer =*/true) ||
-          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
-        continue;
-
-      // Compute the fused operands list,
-      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
-      fusedOperands[en.index()] = reshapeOp.src();
-      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
-      llvm::append_range(fusedOperands, outputOperands);
-
-      // Compute indexing_maps for the fused operation. The indexing_maps for
-      // the operands of the consumers that arent fused are the same.
-      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
-
-      // Compute the indexing map to use for the result of the producer.
-      AffineMap modifiedMap =
-          linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
-      // The modified map cannot have symbols.
-      if (modifiedMap.getNumSymbols())
-        return failure();
-      for (AffineExpr expr : modifiedMap.getResults()) {
-        if (!expr.isPureAffine())
-          return failure();
-      }
-      fusedIndexMaps[en.index()] = modifiedMap;
-
-      // Further check that the resulting index maps can be fused and
-      // inverted. Without this the resultant op is not legal.
-      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
-        return rewriter.notifyMatchFailure(
-            genericOp, "fused op loop bound computation failed");
-      }
-
-      rewriter.startRootUpdate(genericOp);
-      genericOp->setOperands(fusedOperands);
-      genericOp.indexing_mapsAttr(
-          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
-      rewriter.finalizeRootUpdate(genericOp);
-      return success();
-    }
-    return failure();
-  }
-};
-
-/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
-/// producer. The corresponding index map in the consumer needs to be modified
-/// to linearize the folded dimension.
-template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
-struct FoldConsumerReshapeOpByLinearization
-    : public OpRewritePattern<TensorReshapeOp> {
-  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
-                                PatternRewriter &rewriter) const override {
-    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
-    if (!producer || !producer.hasTensorSemantics() ||
-        producer.getNumOutputs() != 1 ||
-        !isTensorReshapeOpFoldableByLinearization(
-            reshapeOp,
-            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
-            /*asProducer =*/false) ||
-        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
-      return failure();
-    // The indexing_maps for the operands of the fused operation are same as
-    // those for the operands of the producer.
-    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
-
-    // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap = linearizeCollapsedDims(
-        producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
-    for (AffineExpr expr : modifiedMap.getResults()) {
-      if (!expr.isPureAffine()) {
-        return rewriter.notifyMatchFailure(
-            producer, "fused op indexing map is not affine");
-      }
-    }
-    fusedIndexMaps.back() = modifiedMap;
-
-    // Further check that the resulting index maps can be fused and
-    // inverted. Without this the resultant op is not legal.
-    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
-      return rewriter.notifyMatchFailure(
-          producer, "fused op loop bound computation failed");
-    }
-
-    Location loc = producer.getLoc();
-    SmallVector<Value> inputOperands = producer.getInputOperands();
-    Value output = rewriter.create<TensorReshapeOp>(
-        loc, producer.getOutputOperand(0)->get(),
-        reshapeOp.getReassociationExprs());
-    auto fusedOp = rewriter.create<GenericOp>(
-        loc, reshapeOp.getResultType(),
-        /*inputs=*/inputOperands,
-        // TODO: handle outputs.
-        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-        producer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr);
-    auto &fusedRegion = fusedOp->getRegion(0);
-    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
-                               fusedRegion.begin());
-    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
-    return success();
-  }
-};
-} // namespace
-
 //===---------------------------------------------------------------------===//
 // Methods and patterns that fuse reshape ops with elementwise operations by
 // expanding the dimensionality of the elementwise operations.
@@ -1737,174 +1480,6 @@ class FoldWithProducerReshapeOpByCollapsing
 };
 } // namespace
 
-//===---------------------------------------------------------------------===//
-// Methods and patterns to convert tensor.expand_shape -> linalg.generic
-// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
-//===---------------------------------------------------------------------===//
-
-// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
-// collapsing that provides a more general functionality. This pattern is very
-// specific to a particular use case. The fusion by collapsing can provide the
-// same control to clients using the control function there.
-
-static SmallVector<ReassociationIndices>
-getReassociationIndices(ArrayRef<AffineMap> maps) {
-  SmallVector<ReassociationIndices> reassociation;
-  for (AffineMap map : maps) {
-    ReassociationIndices indices;
-    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
-      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
-      indices.push_back(pos);
-    }
-    reassociation.push_back(indices);
-  }
-  return reassociation;
-}
-
-namespace {
-/// Pattern to move rank reducing reshape after an elementwise linalg generic
-/// op. This is useful to expose more fusion opportunities between named ops and
-/// generic ops. This can only be done if there is no broadcast or permuation
-/// within the dimensions we need to merge.
-///
-/// For example,
-///
-///  %0 = tensor.expand_shape %A [[0, 1], [2]]
-///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
-///  %2 = linalg.generic {indexing_maps = [
-///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
-///    affine_map<(d0, d1, d2) -> (d2)>,
-///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
-///    ["parallel", "parallel", "parallel"]} {
-///  } -> tensor<112x112x16xf32>
-///
-///  into
-///
-///  %2 = linalg.generic {indexing_maps = [
-///    affine_map<(d0, d1) -> (d0, d1)>,
-///    affine_map<(d0, d1) -> (d1)>,
-///    affine_map<(d0, d1) -> (d0, d1)>],
-///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
-///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
-///  } -> tensor<12544x16xf32>
-///  %3 = tensor.expand_shape %2 [[0, 1], [2]]
-///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
-struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    // Only apply to elementwise linalg on tensor.
-    if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
-        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
-      return failure();
-    // Only support identity output maps. It could be extended to permuations if
-    // needed.
-    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
-          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
-        }))
-      return failure();
-    int64_t destRank = genericOp.getNumParallelLoops();
-    SmallVector<Value> newOperands = genericOp.getInputOperands();
-    tensor::ExpandShapeOp reshapeFound;
-    // 1. Look for tensor_expand_shape operands and figure out save the
-    // dimensions merged.
-    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
-    for (const auto &en : llvm::enumerate(inputOperands)) {
-      auto reshapeOp =
-          en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
-      if (!reshapeOp)
-        continue;
-      // TODO: We could support non-identity map as long as the merged
-      // dimensions are still contiguous.
-      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
-        continue;
-      if (reshapeFound) {
-        // Only support a second reshape op if it has the same reassociate maps.
-        if (reshapeFound.getReassociationMaps() ==
-            reshapeOp.getReassociationMaps())
-          newOperands[en.index()] = reshapeOp.src();
-        continue;
-      }
-      reshapeFound = reshapeOp;
-      newOperands[en.index()] = reshapeOp.src();
-    }
-    if (!reshapeFound)
-      return failure();
-
-    // Calculate the reassociation indices and rassociated reverse map.
-    SmallVector<ReassociationIndices> reassociation =
-        getReassociationIndices(reshapeFound.getReassociationMaps());
-    SmallVector<unsigned> remap(destRank);
-    for (auto &indices : llvm::enumerate(reassociation)) {
-      for (int64_t index : indices.value()) {
-        remap[index] = indices.index();
-      }
-    }
-    // 2. Verify that we can merge the dimensions in the linalg and that we
-    // don't need to create new reshapes operands. Inserting new reshape
-    // operands would defeat the purpose of the transformation.
-    for (const auto &en : llvm::enumerate(inputOperands)) {
-      if (en.value()->get() == newOperands[en.index()]) {
-        AffineMap map = genericOp.getTiedIndexingMap(en.value());
-        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
-          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
-            return failure();
-        }
-      }
-    }
-
-    // 3. Calculate the affine map remapping and the reassociation to apply to
-    // output tensors.
-    SmallVector<AffineMap> newMaps;
-    unsigned newRank = reassociation.size();
-    for (auto map : genericOp.getIndexingMaps()) {
-      SmallVector<AffineExpr> newExprs;
-      for (auto expr : map.getResults()) {
-        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
-        // Skip dimension merged except for the last of the group.
-        if (reassociation[remap[position]].back() == position) {
-          newExprs.push_back(
-              getAffineDimExpr(remap[position], genericOp.getContext()));
-        }
-      }
-      newMaps.push_back(
-          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
-    }
-
-    // 4. Reshape the output tensors.
-    SmallVector<Value> newOutputs;
-    SmallVector<Type> newOutputTypes;
-    for (auto output : genericOp.outputs()) {
-      auto newOutputType = RankedTensorType::get(
-          reshapeFound.getSrcType().getShape(),
-          output.getType().template cast<RankedTensorType>().getElementType());
-      Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
-          genericOp->getLoc(), newOutputType, output, reassociation);
-      newOutputTypes.push_back(newOutputType);
-      newOutputs.push_back(newOutput);
-    }
-    // 5. Create a new generic op with lowerer rank.
-    SmallVector<StringRef> iteratorTypes(newRank,
-                                         getParallelIteratorTypeName());
-    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
-                                            newOperands, newOutputs, newMaps,
-                                            iteratorTypes);
-    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
-                                newOp.region().begin());
-    // 6. Reshape the so that the type matches the uses.
-    SmallVector<Value> newResults;
-    for (const auto &result : llvm::enumerate(newOp->getResults())) {
-      newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
-          result.value(), reassociation));
-    }
-    rewriter.replaceOp(genericOp, newResults);
-    return success();
-  }
-};
-} // namespace
-
 //===---------------------------------------------------------------------===//
 // Methods and patterns that fuse constants with linalg.generic operations.
 //===---------------------------------------------------------------------===//
@@ -2093,27 +1668,6 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
   }
 };
 } // namespace
-//===---------------------------------------------------------------------===//
-// Methods that add patterns described in this file to a pattern list.
-//===---------------------------------------------------------------------===//
-
-void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<
-      FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>,
-      FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>,
-      FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
-      patterns.getContext());
-}
-
-void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
-    RewritePatternSet &patterns) {
-  patterns
-      .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>,
-           FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>,
-           FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
-          patterns.getContext());
-}
 
 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     RewritePatternSet &patterns,
@@ -2140,28 +1694,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
                RemoveOutsDependency>(context);
 }
 
-void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
-  auto *context = patterns.getContext();
-  patterns.add<PushExpandingReshape>(context);
-}
-
 //===---------------------------------------------------------------------===//
 // Passes
 //===---------------------------------------------------------------------===//
 
-bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
-                                      OpOperand &consumer) {
-  if (auto producerCollapseOp =
-          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
-    return !isUnitDimExpansionOnly(producerCollapseOp);
-  }
-  if (auto consumerExpandOp =
-          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
-    return !isUnitDimExpansionOnly(consumerExpandOp);
-  }
-  return true;
-}
-
 namespace {
 
 /// Pass that fuses generic ops on tensors. Used only for testing.
@@ -2186,9 +1722,7 @@ struct LinalgElementwiseOpFusionPass
     // Add elementwise op fusion patterns.
     populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
 
-    populateFoldReshapeOpsByExpansionPatterns(
-        patterns,
-        allowFoldingUnitDimReshapes ? defaultControlFn : skipUnitDimReshape);
+    populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
 
     // Add the sparse tensor rewriting patterns.
     populateSparseTensorRewriting(patterns);
@@ -2212,27 +1746,8 @@ struct LinalgElementwiseOpFusionPass
   }
 };
 
-/// Pass to test folding of reshape ops with generic ops by linearization.
-struct FoldReshapeOpsByLinearizationPass
-    : public LinalgFoldReshapeOpsByLinearizationBase<
-          FoldReshapeOpsByLinearizationPass> {
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    RewritePatternSet patterns(op->getContext());
-    populateFoldReshapeOpsByLinearizationPatterns(patterns);
-    if (allowFoldingUnitDimReshapes) {
-      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
-    }
-    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
-  }
-};
-
 } // namespace
 
 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
   return std::make_unique<LinalgElementwiseOpFusionPass>();
 }
-
-std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
-  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
-}

diff  --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index ea699d820b610..33489cba431fc 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing -split-input-file | FileCheck %s
 
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
@@ -124,30 +124,3 @@ func.func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
 //  CHECK-SAME:   outs(%{{.+}} : tensor<6x5xf32>)
 //       CHECK:   tensor.expand_shape %[[OP]]
 //  CHECK-SAME:   tensor<6x5xf32> into tensor<2x3x5xf32>
-
-// -----
-
-func.func @generic_op_index_semantics(%A: tensor<?x16xi64>, %B: tensor<16xi64>, %init: tensor<?x112x16xi64>) -> tensor<?x112x16xi64> {
-  %0 = tensor.expand_shape %A [[0, 1], [2]]
-      : tensor<?x16xi64> into tensor<?x112x16xi64>
-  %2 = linalg.generic {indexing_maps = [
-    affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
-    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
-    iterator_types = ["parallel", "parallel", "parallel"]}
-  ins(%0, %B : tensor<?x112x16xi64>, tensor<16xi64>)
-  outs(%init : tensor<?x112x16xi64>) {
-  ^bb0(%arg1: i64, %arg2: i64, %arg3: i64):  // no predecessors
-    %index = linalg.index 0 : index
-    %1 = arith.index_cast %index : index to i64
-    %add = arith.addi %arg1, %1 : i64
-    %s = arith.subi %add, %arg2 : i64
-    linalg.yield %s : i64
-  } -> tensor<?x112x16xi64>
-  return %2 : tensor<?x112x16xi64>
-}
-//      CHECK: func @generic_op_index_semantics
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x16xi64>
-//      CHECK:   %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
-//      CHECK:   %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME:       ins(%[[RESHAPE]]
-//      CHECK:   return %[[RESULT]]

diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ebee7e75ac5a1..45e8721278e1c 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -linalg-fuse-elementwise-ops="allow-folding-unit-dim-reshapes=false" -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -linalg-fuse-elementwise-ops="allow-folding-unit-dim-reshapes=true" -split-input-file | FileCheck %s --check-prefix=FOLDUNITDIM
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-expansion -split-input-file | FileCheck %s
+
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
 #map2 = affine_map<(d0, d1, d2) -> ()>
@@ -14,7 +14,7 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
      indexing_maps = [#map0, #map1, #map2, #map1],
      iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%0, %arg1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32)
-       outs(%0 : tensor<?x?x?xf32>) {
+       outs(%arg1 : tensor<?x?x?xf32>) {
     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32):       
       %1 = arith.mulf %arg3, %arg4 : f32
       %2 = arith.addf %1, %arg5 : f32
@@ -30,15 +30,15 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
-//      CHECK:   %[[T0:.+]] = tensor.collapse_shape %[[ARG0]]
-// CHECK-SAME:     [0], [1, 2], [3]
 //      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
 // CHECK-SAME:     [0], [1], [2, 3]
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0], [1], [2, 3]
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>, f32)
-// CHECK-SAME:     outs(%{{.+}} : tensor<?x?x?x4xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x?x4xf32>)
 //      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
 // CHECK-SAME:     [0], [1], [2, 3]
 // CHECK-SAME:     tensor<?x?x?x4xf32> into tensor<?x?x?xf32>
@@ -80,12 +80,14 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
 //      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
 // CHECK-SAME:     [0], [1, 2, 3]
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1, 2, 3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[T0]], %[[T1]], %[[ARG2]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>, f32)
-// CHECK-SAME:     outs(%{{.+}} : tensor<?x4x?x5xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x4x?x5xf32>)
 //      CHECK:   return %[[T3]] : tensor<?x4x?x5xf32>
 
 
@@ -121,11 +123,14 @@ func.func @reshape_as_consumer_permutation
 //      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
 // CHECK-SAME:     [0, 1, 2], [3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<3x4x?x?xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME:     [0, 1], [2], [3, 4, 5]]
+// CHECK-SAME:     tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
-// CHECK-SAME:     outs(%{{.+}} : tensor<?x2x?x3x4x?xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
 //      CHECK:   return %[[T3]] : tensor<?x2x?x3x4x?xf32>
 
 // -----
@@ -155,14 +160,19 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK: func @generic_op_reshape_consumer_static
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
+//  CHECK-DAG:   %[[CST:.+]] = arith.constant
+// CHECK-SAME:     : tensor<8x33x4xf32>
+//  CHECK-DAG:   %[[INIT:.+]] = linalg.init_tensor [264, 4]
 //      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
 // CHECK-SAME:     [0, 1], [2]
 // CHECK-SAME:     tensor<264x4xf32> into tensor<8x33x4xf32>
-//      CHECK:   %[[T1:.+]] = linalg.init_tensor [8, 33, 4]
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[INIT]]
+// CHECK-SAME:     [0, 1], [2]
+// CHECK-SAME:     : tensor<264x4xf32> into tensor<8x33x4xf32>
 //      CHECK:   %[[T2:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[T0]] : tensor<8x33x4xf32>)
+// CHECK-SAME:     ins(%[[T0]], %[[CST]] :
 // CHECK-SAME:     outs(%[[T1]] : tensor<8x33x4xf32>)
 //      CHECK:   return %[[T2]] : tensor<8x33x4xf32>
 
@@ -246,7 +256,8 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
 }
 
 // Only check the body in the indexed version of the test.
-//       CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)>
 //       CHECK: func @indexed_producer_reshape_consumer_fusion
 //       CHECK:   linalg.generic
 //       CHECK:   ^{{.*}}(
@@ -256,11 +267,12 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
 //   CHECK-DAG:     %[[IDX1:.+]] = linalg.index 1 : index
 //   CHECK-DAG:     %[[IDX2:.+]] = linalg.index 2 : index
 //   CHECK-DAG:     %[[IDX3:.+]] = linalg.index 3 : index
-//   CHECK-DAG:     %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX3]], %[[IDX2]], %[[IDX1]])
+//       CHECK:     %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]])
+//       CHECK:     %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]])
 //       CHECK:     %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
 //       CHECK:     %[[T5:.+]] = arith.index_cast %[[IDX0]]
 //       CHECK:     %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
-//       CHECK:     %[[T7:.+]] = arith.index_cast %[[T3]]
+//       CHECK:     %[[T7:.+]] = arith.index_cast %[[T2]]
 //       CHECK:     %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
 //       CHECK:     linalg.yield %[[T8]]
 
@@ -295,24 +307,29 @@ func.func @reshape_as_consumer_permutation
   return %d : tensor<2x3x4x5x6x7xi32>
 }
 
+// -----
 
-//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-//   CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-//   CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-//   CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-//   CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+//   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)>
+//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)>
 //       CHECK: func @reshape_as_consumer_permutation
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<210x6x4xi32>
 //  CHECK-SAME:   %[[ARG1:.+]]: tensor<210x4xi32>
+//   CHECK-DAG:   %[[INIT:.+]] = linalg.init_tensor [6, 4, 210]
 //   CHECK-DAG:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]]
 //  CHECK-SAME:     [0, 1, 2], [3, 4], [5]
 //   CHECK-DAG:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
 //  CHECK-SAME:     [0, 1, 2], [3]
-//   CHECK-DAG:   %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
+//   CHECK-DAG:   %[[T3:.+]] = tensor.expand_shape %[[INIT]]
+//  CHECK-SAME:     [0, 1], [2], [3, 4, 5]
+//  CHECK-SAME:     : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
 //       CHECK:   %[[T4:.+]] = linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 //  CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
-//  CHECK-SAME:     outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
+//  CHECK-SAME:     outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
 //       CHECK:   ^{{.+}}(
 //  CHECK-SAME:     %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
 //  CHECK-SAME:     %[[ARG10:[a-zA-Z0-9]+]]: i32)
@@ -322,15 +339,16 @@ func.func @reshape_as_consumer_permutation
 //   CHECK-DAG:       %[[IDX3:.+]] = linalg.index 3 : index
 //   CHECK-DAG:       %[[IDX4:.+]] = linalg.index 4 : index
 //   CHECK-DAG:       %[[IDX5:.+]] = linalg.index 5 : index
-//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP8]](%[[IDX1]], %[[IDX0]])
-//   CHECK-DAG:       %[[T6:.+]] = affine.apply #[[MAP9]](%[[IDX4]], %[[IDX3]], %[[IDX2]])
-//   CHECK-DAG:       %[[T7:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
-//       CHECK:       %[[T8:.+]] = arith.index_cast %[[T5]]
-//       CHECK:       %[[T9:.+]] = arith.addi %[[T7]], %[[T8]]
-//       CHECK:       %[[T10:.+]] = arith.index_cast %[[T6]]
-//       CHECK:       %[[T11:.+]] = arith.addi %[[T9]], %[[T10]]
-//       CHECK:       %[[T12:.+]] = arith.index_cast %[[IDX5]]
-//       CHECK:       %[[T13:.+]] = arith.addi %[[T11]], %[[T12]]
+//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]])
+//   CHECK-DAG:       %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]])
+//   CHECK-DAG:       %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]])
+//   CHECK-DAG:       %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
+//       CHECK:       %[[T9:.+]] = arith.index_cast %[[T5]]
+//       CHECK:       %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
+//       CHECK:       %[[T11:.+]] = arith.index_cast %[[T7]]
+//       CHECK:       %[[T12:.+]] = arith.addi %[[T10]], %[[T11]]
+//       CHECK:       %[[T13:.+]] = arith.index_cast %[[IDX5]]
+//       CHECK:       %[[T14:.+]] = arith.addi %[[T12]], %[[T13]]
 
 // -----
 
@@ -421,94 +439,18 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 //      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
 // CHECK-SAME:     [0, 1, 2], [3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
-// CHECK-SAME:     outs(%{{.+}} : tensor<?x?x4x5xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
 //      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>
 
 // -----
 
-func.func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
-  %0 = tensor.collapse_shape %arg0 [[0, 1]]
-      : tensor<1x5xf32> into tensor<5xf32>
-  %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
-  %2 = linalg.generic
-    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
-                      affine_map<(d0, d1) -> (d0, d1)>],
-     iterator_types = ["parallel", "parallel"]}
-    ins(%0 : tensor<5xf32>) outs(%1 : tensor<5x5xf32>) {
-  ^bb0(%arg2: f32, %arg3: f32):  
-    linalg.yield %arg2 : f32
-  } -> tensor<5x5xf32>
-  return %2 : tensor<5x5xf32>
-}
-//      CHECK: func @unit_dim_reshape_expansion
-//  CHECK-DAG:   tensor.collapse_shape
-//  CHECK-DAG:   linalg.init_tensor
-//      CHECK:   linalg.generic
-
-// -----
-
-func.func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
-  %0 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
-  %1 = linalg.generic
-    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
-                      affine_map<(d0, d1) -> (d0, d1)>],
-     iterator_types = ["parallel", "parallel"]}
-    ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5x5xf32>) {
-  ^bb0(%arg2: f32, %arg3: f32):  
-    linalg.yield %arg2 : f32
-  } -> tensor<5x5xf32>
-  %2 = tensor.expand_shape %1 [[0, 1], [2]]
-    : tensor<5x5xf32> into tensor<5x1x5xf32>
-  return %2 : tensor<5x1x5xf32>
-}
-// CHECK: func @unit_dim_reshape_collapse
-// CHECK:   linalg.init_tensor
-// CHECK:   linalg.generic
-// CHECK:   tensor.expand_shape
-
-// -----
-
-func.func @unit_dim_reshape_expansion_full
-  (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
-  -> tensor<?x2x4xf32> {
-  %c1 = arith.constant 1 : index
-  %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
-    : tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
-  %1 = tensor.dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
-  %2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
-  %3 = linalg.generic
-    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
-                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
-                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
-     iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%0, %arg1 : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
-    outs(%2 : tensor<?x2x4xf32>) {
-  ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  
-    %4 = arith.mulf %arg2, %arg3 : f32
-    linalg.yield %4 : f32
-  } -> tensor<?x2x4xf32>
-  return %3 : tensor<?x2x4xf32>
-}
-//      CHECK: func @unit_dim_reshape_expansion_full
-//  CHECK-DAG:   tensor.collapse_shape
-//  CHECK-DAG:   linalg.init_tensor
-//      CHECK:   linalg.generic
-// CHECK-SAME:     ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
-
-//         FOLDUNITDIM: func @unit_dim_reshape_expansion_full
-//    FOLDUNITDIM-SAME:   %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32>
-//    FOLDUNITDIM-SAME:   %[[ARG1:.+]]: tensor<?x2x4xf32>
-//     FOLDUNITDIM-DAG:   %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG1]]
-//         FOLDUNITDIM:   linalg.generic
-//    FOLDUNITDIM-SAME:     ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
-//    FOLDUNITDIM-SAME:     outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
-
-// -----
-
 func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
@@ -554,7 +496,6 @@ func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<2x1xi64>
 // CHECK-SAME:     %[[ARG1:.+]]: tensor<?xi64>
 //      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
-//      CHECK:   %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?xi64> to tensor<2xi64>
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME:       ins(%[[RESHAPE]], %[[CAST]] : tensor<2xi64>, tensor<2xi64>)
+// CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
 //      CHECK:   return %[[GENERIC]]

diff  --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
deleted file mode 100644
index 089b30694231f..0000000000000
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ /dev/null
@@ -1,287 +0,0 @@
-// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
-
-// Note: These tests fuse the reshape ops by linearization. This can create
-// indexing maps which are hard to analyse later on. These patterns are useful
-// only if the folded dimensions in the reshape op are unit extent. Tests here
-// are more general for testing purposes, but use of these pattern for non-unit
-// dimensions should be deprecated.
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
-  -> tensor<?x?x4x?xi32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] :
-    tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
-  %1 = linalg.generic {
-    indexing_maps = [#map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
-    ins(%0 : tensor<?x?x4x?xi32>)
-    outs(%0 : tensor<?x?x4x?xi32>) {
-  ^bb0(%arg6: i32, %arg7 : i32):       
-    %idx = linalg.index 0 : index
-    %2 = arith.index_cast %idx : index to i32
-    %3 = arith.addi %arg6, %2 : i32
-    linalg.yield %3 : i32
-  } -> tensor<?x?x4x?xi32>
-  return %1 : tensor<?x?x4x?xi32>
-}
-//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-//   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-//       CHECK: func @generic_op_reshape_producer_fusion
-//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xi32>
-//       CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
-//  CHECK-SAME:     [0], [1, 2], [3]
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP3]], #[[MAP4]]]
-//  CHECK-SAME:     ins(%[[ARG0]] : tensor<?x?x?xi32>)
-//  CHECK-SAME:     outs(%[[T0]] : tensor<?x?x4x?xi32>)
-//       CHECK:   %[[IDX:.+]] = linalg.index 0 : index
-//  CHECK-NEXT:   %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
-  -> tensor<?x?xi32> {
-  %0 = linalg.generic {
-    indexing_maps = [#map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
-    ins(%arg0 : tensor<?x?x4x5xi32>) outs(%arg0 : tensor<?x?x4x5xi32>) {
-  ^bb0(%arg6: i32, %arg7: i32):       
-    %idx = linalg.index 0 : index
-    %2 = arith.index_cast %idx : index to i32
-    %3 = arith.addi %arg6, %2 : i32
-    linalg.yield %3 : i32
-  } -> tensor<?x?x4x5xi32>
-  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
-    tensor<?x?x4x5xi32> into tensor<?x?xi32>
-  return %1 : tensor<?x?xi32>
-}
-//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-//       CHECK: func @generic_op_reshape_consumer_fusion
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
-//       CHECK:   %[[T0:.+]] = tensor.collapse_shape %[[ARG0]]
-//  CHECK-SAME:     [0], [1, 2, 3]
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
-//  CHECK-SAME:     outs(%[[T0]] : tensor<?x?xi32>)
-//       CHECK:   %[[IDX:.+]] = linalg.index 0 : index
-//  CHECK-NEXT:   %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32
-//   CHECK-NOT:   tensor.collapse_shape
-
-// -----
-
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2]]
-      : tensor<3x35xf32> into tensor<3x5x7xf32>
-  %1 = linalg.init_tensor [3, 7, 5] : tensor<3x7x5xf32>
-  %2 = linalg.generic
-    {indexing_maps = [#map2, #map3],
-     iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<3x7x5xf32>) {
-    ^bb0(%arg2: f32, %arg3 : f32):  
-      linalg.yield %arg2 : f32
-    } -> tensor<3x7x5xf32>
-    return %2 : tensor<3x7x5xf32>
-}
-
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-//       CHECK: func @generic_op_021_permultation_reshape_producer_fusion
-//   CHECK-NOT:   tensor.expand_shape
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
-
-// -----
-
-#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
-func.func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2]]
-      : tensor<3x35xf32> into tensor<3x5x7xf32>
-  %1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32>
-  %2 = linalg.generic
-    {indexing_maps = [#map2, #map3],
-     iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x7x3xf32>) {
-    ^bb0(%arg2: f32, %arg3: f32):  
-      linalg.yield %arg2 : f32
-    } -> tensor<5x7x3xf32>
-    return %2 : tensor<5x7x3xf32>
-}
-
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
-//       CHECK: func @generic_op_120_permutation_reshape_producer_fusion
-//   CHECK-NOT:   tensor.expand_shape
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2) -> (d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2]]
-      : tensor<3x35xf32> into tensor<3x5x7xf32>
-  %1 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
-  %2 = linalg.generic
-    {indexing_maps = [#map2, #map3],
-     iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x3x7xf32>) {
-    ^bb0(%arg2: f32, %arg3: f32):  
-      linalg.yield %arg2 : f32
-    } -> tensor<5x3x7xf32>
-    return %2 : tensor<5x3x7xf32>
-}
-
-
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-//       CHECK: func @generic_op_102_permultation_reshape_producer_fusion
-//   CHECK-NOT:   tensor.expand_shape
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0)>
-#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
-func.func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
-  %0 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
-  %1 = linalg.generic
-    {indexing_maps = [#map0, #map1],
-     iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%arg0 : tensor<3x5x7xf32>) outs(%0 : tensor<5x3x7xf32>) {
-    ^bb0(%arg2: f32, %arg3 : f32):  
-      linalg.yield %arg2 : f32
-  } -> tensor<5x3x7xf32>
-  %2 = tensor.collapse_shape %1 [[0], [1, 2]]
-      : tensor<5x3x7xf32> into tensor<5x21xf32>
-  return %2 : tensor<5x21xf32>
-}
-//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-//       CHECK: func @generic_op_102_permultation_reshape_consumer_fusion
-//  CHECK-SAME:   %[[ARG0:.+]]: tensor<3x5x7xf32>
-//       CHECK:   %[[T0:.+]] = linalg.init_tensor [5, 3, 7]
-//       CHECK:   %[[T1:.+]] = tensor.collapse_shape %[[T0]]
-//  CHECK-SAME:     [0], [1, 2]
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
-//  CHECK-SAME:     ins(%[[ARG0]] : tensor<3x5x7xf32>)
-//  CHECK-SAME:     outs(%[[T1]] : tensor<5x21xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
-                                           %arg1 : tensor<?x?x?x5xf32>) ->
-                                           tensor<?x?xf32>
-{
-  %0 = linalg.generic {
-     indexing_maps = [#map0, #map0, #map0],
-     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-      ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>)
-      outs(%arg0 : tensor<?x?x?x5xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):       
-      %1 = arith.mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
-  } -> tensor<?x?x?x5xf32>
-  %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
-    tensor<?x?x?x5xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x5xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x5xf32>
-//       CHECK:   %[[NOFUSE:.+]] = linalg.generic
-//  CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]]
-//       CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[NOFUSE]]
-//       CHECK:   return %[[RESULT]]
-
-
-// -----
-
-func.func @generic_op_permultation_reshape_consumer_fusion_unused_dim(%arg0 : tensor<6x1xf32>) -> tensor<6xi32> {
-  %0 = linalg.init_tensor [6, 1] : tensor<6x1xi32>
-  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                                        affine_map<(d0, d1) -> (d0, d1)>],
-   iterator_types = ["parallel", "parallel"]}
-   ins(%arg0 : tensor<6x1xf32>) outs(%0 : tensor<6x1xi32>) {
-    ^bb0(%arg3: f32, %arg4: i32):  
-      %5 = arith.fptosi %arg3 : f32 to i32
-      linalg.yield %5 : i32
-    } -> tensor<6x1xi32>
-    %6 = tensor.collapse_shape %1 [[0, 1]] : tensor<6x1xi32> into tensor<6xi32>
-  return %6 : tensor<6xi32>
-}
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
-//       CHECK: func @generic_op_permultation_reshape_consumer_fusion_unused_dim
-//  CHECK-SAME:   %[[ARG0:.+]]: tensor<6x1xf32>
-//       CHECK:   %[[T0:.+]] = linalg.init_tensor [6, 1]
-//       CHECK:   %[[T1:.+]] = tensor.collapse_shape %[[T0]]
-//  CHECK-SAME:   [0, 1]
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
-//  CHECK-SAME:     ins(%[[ARG0]] : tensor<6x1xf32>)
-//  CHECK-SAME:     outs(%[[T1]] : tensor<6xi32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
-func.func @permuted_dims_fusion_expand_shape(%arg0 : tensor<3x8x7x240xf32>) -> tensor<4x6x3x8x2x5x7xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6]]
-      : tensor<3x8x7x240xf32> into tensor<3x2x4x7x8x5x6xf32>
-  %1 = linalg.init_tensor [4, 6, 3, 8, 2, 5, 7] : tensor<4x6x3x8x2x5x7xf32>
-  %2 = linalg.generic {
-      indexing_maps = [#map0, #map1],
-      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
-      ins(%0 : tensor<3x2x4x7x8x5x6xf32>) outs(%1 : tensor<4x6x3x8x2x5x7xf32>) {
-      ^bb0(%arg1 : f32, %arg2 : f32):
-        linalg.yield %arg1 : f32
-      } -> tensor<4x6x3x8x2x5x7xf32>
-  return %2 : tensor<4x6x3x8x2x5x7xf32>
-}
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
-//      CHECK: func @permuted_dims_fusion_expand_shape(
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<3x8x7x240xf32>)
-//      CHECK:   %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME:       ins(%[[ARG0]] :
-//      CHECK:   return %[[RESULT]]
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
-func.func @permuted_dims_fusion_collapse_shape(%arg0 : tensor<4x6x3x8x2x5x7xf32>) -> tensor<3x8x7x240xf32> {
-  %0 = linalg.init_tensor [3, 2, 4, 7, 8, 5, 6] : tensor<3x2x4x7x8x5x6xf32>
-  %1 = linalg.generic {
-      indexing_maps = [#map1, #map0],
-      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
-      ins(%arg0 : tensor<4x6x3x8x2x5x7xf32>) outs(%0 : tensor<3x2x4x7x8x5x6xf32>) {
-      ^bb0(%arg1 : f32, %arg2 : f32):
-        linalg.yield %arg1 : f32
-      } -> tensor<3x2x4x7x8x5x6xf32>
-  %2 = tensor.collapse_shape %1 [[0], [1, 2], [3], [4, 5, 6]]
-      : tensor<3x2x4x7x8x5x6xf32> into tensor<3x8x7x240xf32>
-  return %2 : tensor<3x8x7x240xf32>
-}
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>
-//      CHECK: func @permuted_dims_fusion_collapse_shape(
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<4x6x3x8x2x5x7xf32>)
-//      CHECK:   %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME:       ins(%[[ARG0]] :
-//      CHECK:   return %[[RESULT]]

diff  --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
deleted file mode 100644
index 80826057c6bd3..0000000000000
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
+++ /dev/null
@@ -1,52 +0,0 @@
-// RUN: mlir-opt -linalg-fold-reshape-ops-by-linearization=allow-folding-unit-dim-reshapes -split-input-file %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @do_not_fold1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?x1xf32>
-{
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
-  %3 = linalg.generic {
-      indexing_maps = [#map, #map, #map],
-      iterator_types = ["parallel", "parallel"]}
-      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%2 : tensor<?x?xf32>) {
-      ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
-        %4 = arith.addf %arg2, %arg3 : f32
-        linalg.yield %4 : f32
-      } -> tensor<?x?xf32>
-  %4 = tensor.expand_shape %3 [[0], [1, 2]] : tensor<?x?xf32> into tensor<?x?x1xf32>
-  return %4 : tensor<?x?x1xf32>
-}
-// CHECK-LABEL: func @do_not_fold1
-//       CHECK: %[[VAL:.+]] = linalg.generic
-//       CHECK: tensor.expand_shape %[[VAL]]
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-func.func @do_not_fold2(%arg0 : tensor<?x?x1xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
-{
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %0 = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<?x?x1xf32> into tensor<?x?xf32>
-  %1 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
-  %2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
-  %3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
-  %4 = linalg.generic {
-      indexing_maps = [#map, #map, #map],
-      iterator_types = ["parallel", "parallel"]}
-      ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%3 : tensor<?x?xf32>) {
-      ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
-        %4 = arith.addf %arg2, %arg3 : f32
-        linalg.yield %4 : f32
-      } -> tensor<?x?xf32>
-  return %4 : tensor<?x?xf32>
-}
-// CHECK-LABEL: func @do_not_fold2
-//       CHECK: %[[VAL:.+]] = tensor.collapse_shape
-//       CHECK: linalg.generic
-//  CHECK-SAME:   ins(%[[VAL]], %{{.+}} : tensor<?x?xf32>, tensor<?x?xf32>)

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 211ddcfc3730a..ec36b0fe9d266 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -70,18 +70,18 @@ struct TestLinalgElementwiseFusion
       llvm::cl::desc("Test fusion of generic operations."),
       llvm::cl::init(false)};
 
+  Option<bool> fuseWithReshapeByExpansion{
+      *this, "fuse-with-reshape-by-expansion",
+      llvm::cl::desc(
+          "Test fusion of generic operations with reshape by expansion"),
+      llvm::cl::init(false)};
+
   Option<bool> controlFuseByExpansion{
       *this, "control-fusion-by-expansion",
       llvm::cl::desc(
           "Test controlling fusion of reshape with generic op by expansion"),
       llvm::cl::init(false)};
 
-  Option<bool> pushExpandingReshape{
-      *this, "push-expanding-reshape",
-      llvm::cl::desc("Test linalg expand_shape -> generic "
-                     "to generic -> expand_shape pattern"),
-      llvm::cl::init(false)};
-
   Option<bool> fuseWithReshapeByCollapsing{
       *this, "fuse-with-reshape-by-collapsing",
       llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
@@ -109,6 +109,17 @@ struct TestLinalgElementwiseFusion
       return;
     }
 
+    if (fuseWithReshapeByExpansion) {
+      RewritePatternSet fusionPatterns(context);
+      linalg::populateFoldReshapeOpsByExpansionPatterns(
+          fusionPatterns, [](const OpResult & /*producer*/,
+                             OpOperand & /*consumer*/) { return true; });
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(fusionPatterns))))
+        return signalPassFailure();
+      return;
+    }
+
     if (controlFuseByExpansion) {
       RewritePatternSet fusionPatterns(context);
 
@@ -128,8 +139,9 @@ struct TestLinalgElementwiseFusion
                 if (linalgOp && linalgOp.isOutputTensor(&use))
                   return true;
               }
+              return false;
             }
-            return linalg::skipUnitDimReshape(producer, consumer);
+            return true;
           };
 
       linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
@@ -139,12 +151,6 @@ struct TestLinalgElementwiseFusion
       return;
     }
 
-    if (pushExpandingReshape) {
-      RewritePatternSet patterns(context);
-      linalg::populatePushReshapeOpsPatterns(patterns);
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
-    }
-
     if (fuseWithReshapeByCollapsing) {
       RewritePatternSet patterns(context);
       linalg::populateFoldReshapeOpsByCollapsingPatterns(


        


More information about the Mlir-commits mailing list