[Mlir-commits] [mlir] de2568a - [mlir][Linalg] Rethink fusion of linalg ops with reshape ops.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 14 14:00:38 PDT 2020


Author: MaheshRavishankar
Date: 2020-10-14T13:50:31-07:00
New Revision: de2568aab819f4ae97a9d92ea68ef1a8ab56ae8c

URL: https://github.com/llvm/llvm-project/commit/de2568aab819f4ae97a9d92ea68ef1a8ab56ae8c
DIFF: https://github.com/llvm/llvm-project/commit/de2568aab819f4ae97a9d92ea68ef1a8ab56ae8c.diff

LOG: [mlir][Linalg] Rethink fusion of linalg ops with reshape ops.

The current fusion on tensors fuses reshape ops with generic ops by
linearizing the indexing maps of the fused tensor in the generic
op. This has some limitations:
- It only works for static shapes.
- The resulting indexing map has a linearization that could
  potentially prevent fusion later on (e.g. tile + fuse).

Instead, try to fuse the reshape consumer (producer) with the generic op
producer (consumer) by expanding the dimensionality of the generic op
when the reshape is expanding (folding). Since this approach conflicts
with the linearization approach, the expansion method is used instead
of the linearization method.

This also includes further refactoring that turns the fusion on
tensors into a collection of patterns.

Differential Revision: https://reviews.llvm.org/D89002

Added: 
    mlir/test/Dialect/Linalg/reshape_fusion.mlir
    mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 68e210582d66..23e221d0b237 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -85,7 +85,7 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
               "ArrayRef<NamedAttribute> attrs = {}", [{
       auto reassociationMaps =
           convertReassociationIndicesToMaps($_builder, reassociation);
-      build($_builder, $_state, src, reassociationMaps, attrs);
+      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 2d6f6f54649f..a0235cf87fdb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -20,6 +20,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
+std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
 
 std::unique_ptr<OperationPass<FuncOp>>
 createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
@@ -48,6 +49,19 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 /// buffers instead.
 std::unique_ptr<OperationPass<ModuleOp>> createLinalgBufferizePass();
 
+/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
+/// producer (consumer) generic operation by expanding the dimensionality of the
+/// loop in the generic op.
+void populateFoldReshapeOpsByExpansionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation.
+void populateFoldReshapeOpsByLinearizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
 /// Patterns for fusing linalg operation on tensors.
 void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
                                            OwningRewritePatternList &patterns);

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 7103ee84b7cf..2df2051255c2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -35,6 +35,14 @@ def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
   let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 
+def LinalgFoldReshapeOpsByLinearization :
+  Pass<"linalg-fold-reshape-ops-by-linearization"> {
+  let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
+                "linearization";
+  let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
+  let dependentDialects = ["AffineDialect"];
+}
+
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index ffcac5f48aa4..61367dd79548 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -91,9 +91,9 @@ Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer,
 
 /// Fuse linalg operation on tensors, with the producer of the operand at
 /// position `consumerIdx` of the consumer.
-Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
-                         unsigned consumerIdx,
-                         OperationFolder *folder = nullptr);
+Optional<SmallVector<Value, 1>>
+fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
+              unsigned consumerIdx, OperationFolder *folder = nullptr);
 
 /// Returns the linearized list of all shape dimensions in a `linalgOp`.
 /// Applying the inverse, concatenated loopToOperandRangeMaps to this list

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7b46348ed0cf..47552c31007f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -514,9 +514,8 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
       return success();
     }
     // Check if producer and consumer are both collapsing dims.
-    else if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(),
-                                   reshapeOp.getSrcType(),
-                                   reshapeOp.getResultType())) {
+    if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(), reshapeOp.getSrcType(),
+                              reshapeOp.getResultType())) {
       rewriter.replaceOpWithNewOp<ReshapeOpTy>(
           reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
           collapseReassociationMaps(srcReshapeOp.getReassociationMaps(),

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index ac57d5f97c1d..52fcd54e13b9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -24,247 +24,240 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-namespace {
-
 /// Implementation of fusion of generic ops and indexed_generic ops.
-struct FuseGenericOpsOnTensors {
-  static bool isFusible(LinalgOp producer, LinalgOp consumer,
-                        unsigned consumerIdx) {
-    // Producer and consumer must have tensor semantics.
-    if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
-      return false;
-
-    // Verify that
-    // - the producer has all "parallel" iterator type.
-    if (producer.getNumParallelLoops() != producer.getNumLoops())
-      return false;
-
-    // Get the consumer index map. The number of results of the consumer index
-    // map must match the number of loops of the producer.
-    AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
-    if (consumerIndexMap.getNumResults() != producer.getNumLoops())
-      return false;
-
-    // Finally the index_map for the result must be invertible. For now just
-    // verify it is a permutation.
-    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
-    return producerResultIndexMap.isPermutation();
-  }
+// struct FuseGenericOpsOnTensors {
+static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
+                                unsigned consumerIdx) {
+  // Producer and consumer must have tensor semantics.
+  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
+    return false;
 
-  static LinalgOp fuse(LinalgOp producer, LinalgOp consumer,
-                       unsigned consumerIdx, PatternRewriter &rewriter,
-                       OperationFolder *folder = nullptr) {
-    if (!isFusible(producer, consumer, consumerIdx))
-      return nullptr;
-
-    unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
-                                consumer.getOperation()->getNumOperands() - 1;
-
-    // Compute the fused operands list,
-    SmallVector<Value, 2> fusedOperands;
-    fusedOperands.reserve(numFusedOperands);
-    auto consumerOperands = consumer.getOperation()->getOperands();
-    auto producerOperands = producer.getOperation()->getOperands();
-    fusedOperands.assign(consumerOperands.begin(),
-                         std::next(consumerOperands.begin(), consumerIdx));
-    fusedOperands.append(producerOperands.begin(), producerOperands.end());
-    fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
-                         consumerOperands.end());
-
-    // Compute indexing_maps for the fused operation. The indexing_maps for the
-    // operands of the consumers that arent fused are the same. The
-    // indexing_maps for the producers need to be computed based on the
-    // indexing_map of the operand at consumerIdx in the consumer.
-    SmallVector<Attribute, 4> fusedIndexMaps;
-    auto consumerIndexMaps = consumer.indexing_maps();
-    fusedIndexMaps.reserve(fusedOperands.size() +
-                           consumer.getOperation()->getNumResults());
-    fusedIndexMaps.assign(consumerIndexMaps.begin(),
-                          std::next(consumerIndexMaps.begin(), consumerIdx));
-    // Compute indexing maps for the producer args in the fused operation.
-    computeProducerOperandIndex(
-        producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
-
-    // Append the indexing maps for the remaining consumer operands.
-    fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
-                          consumerIndexMaps.end());
-
-    // Generate the fused op.
-    // Tensor-level fusion is only on ops without initTensors and outputBuffers.
-    LinalgOp fusedOp;
-    if (isa<GenericOp>(producer.getOperation()) &&
-        isa<GenericOp>(consumer.getOperation())) {
-      fusedOp =
-          rewriter
-              .create<GenericOp>(consumer.getLoc(),
-                                 consumer.getOperation()->getResultTypes(),
-                                 /*inputs=*/fusedOperands,
-                                 /*outputBuffers=*/ValueRange{},
-                                 /*initTensors=*/ValueRange{},
-                                 rewriter.getArrayAttr(fusedIndexMaps),
-                                 consumer.iterator_types(),
-                                 /*doc=*/nullptr,
-                                 /*library_call=*/nullptr,
-                                 /*symbol_source=*/nullptr)
-              .getOperation();
-    } else {
-      fusedOp =
-          rewriter
-              .create<IndexedGenericOp>(
-                  consumer.getLoc(), consumer.getOperation()->getResultTypes(),
-                  /*inputs=*/fusedOperands,
-                  /*outputBuffers=*/ValueRange{},
-                  /*initTensors=*/ValueRange{},
-                  rewriter.getArrayAttr(fusedIndexMaps),
-                  consumer.iterator_types(),
-                  /*doc=*/nullptr,
-                  /*library_call=*/nullptr,
-                  /*symbol_source=*/nullptr)
-              .getOperation();
-    }
+  // Verify that
+  // - the producer has all "parallel" iterator type.
+  if (producer.getNumParallelLoops() != producer.getNumLoops())
+    return false;
 
-    // Construct an AffineMap from consumer loops to producer loops.
-    // consumer loop -> tensor index
-    AffineMap consumerResultIndexMap =
-        consumer.getInputIndexingMap(consumerIdx);
-    // producer loop -> tensor index
-    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
-    // tensor index -> producer loop
-    AffineMap invProducerResultIndexMap =
-        inversePermutation(producerResultIndexMap);
-    assert(invProducerResultIndexMap &&
-           "expected producer result indexig map to be invertible");
-    // consumer loop -> producer loop
-    AffineMap consumerToProducerLoopsMap =
-        invProducerResultIndexMap.compose(consumerResultIndexMap);
-
-    generateFusedRegion(rewriter, fusedOp, producer, consumer,
-                        consumerToProducerLoopsMap, consumerIdx,
-                        consumer.getNumLoops());
-    return fusedOp;
-  }
+  // Get the consumer index map. The number of results of the consumer index
+  // map must match the number of loops of the producer.
+  AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
+  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
+    return false;
 
-private:
-  /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
-  /// the `producer` to use in the fused operation given the indexing map of the
-  /// result of the producer in the consumer.
-  static void computeProducerOperandIndex(
-      LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
-      SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
-    // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
-    // from consumer loop -> consumer arg tensor index/producer result tensor
-    // index. The fused loop is same as the consumer loop. For each producer arg
-    // the indexing map to be computed is a map from consumer loop -> producer
-    // arg tensor index.
-
-    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
-    // producerResultIndexMap is a map from producer loop -> tensor index.
-    // Compute the inverse to get map from tensor index -> producer loop.
-    // The inverse is a map from producer result tensor index -> producer loop.
-    AffineMap invProducerResultIndexMap =
-        inversePermutation(producerResultIndexMap);
-    assert(invProducerResultIndexMap &&
-           "expected producer result indexig map to be invertible");
-    for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
-      // argMap is a map from producer loop -> producer arg tensor index.
-      AffineMap argMap = producer.getInputIndexingMap(argNum);
-
-      // Compose argMap with invProducerResultIndexMap to get a map from
-      // producer result tensor index -> producer arg tensor index.
-      AffineMap t1 = argMap.compose(invProducerResultIndexMap);
-
-      // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
-      // consumer loop/ fused loop -> producer arg tensor index.
-      AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
-      fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
-    }
+  // Finally the index_map for the result must be invertible. For now just
+  // verify it is a permutation.
+  AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+  return producerResultIndexMap.isPermutation();
+}
+
+/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
+/// the `producer` to use in the fused operation given the indexing map of the
+/// result of the producer in the consumer.
+static void getIndexingMapOfProducerOperandsInFusedOp(
+    LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
+    SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
+  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
+  // from consumer loop -> consumer arg tensor index/producer result tensor
+  // index. The fused loop is same as the consumer loop. For each producer arg
+  // the indexing map to be computed is a map from consumer loop -> producer
+  // arg tensor index.
+
+  AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+  // producerResultIndexMap is a map from producer loop -> tensor index.
+  // Compute the inverse to get map from tensor index -> producer loop.
+  // The inverse is a map from producer result tensor index -> producer loop.
+  AffineMap invProducerResultIndexMap =
+      inversePermutation(producerResultIndexMap);
+  assert(invProducerResultIndexMap &&
+         "expected producer result indexig map to be invertible");
+  for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
+    // argMap is a map from producer loop -> producer arg tensor index.
+    AffineMap argMap = producer.getInputIndexingMap(argNum);
+
+    // Compose argMap with invProducerResultIndexMap to get a map from
+    // producer result tensor index -> producer arg tensor index.
+    AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+
+    // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
+    // consumer loop/ fused loop -> producer arg tensor index.
+    AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
+    fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
   }
+}
 
-  /// Generate the region of the fused operation. The region of the fused op
-  /// must be empty.
-  static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp,
-                                  LinalgOp producer, LinalgOp consumer,
-                                  AffineMap consumerToProducerLoopsMap,
-                                  unsigned consumerIdx, unsigned nloops) {
-    // Build the region of the fused op.
-    Block &producerBlock = producer.getOperation()->getRegion(0).front();
-    Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
-    Block *fusedBlock = new Block();
-    fusedOp->getRegion(0).push_back(fusedBlock);
-    BlockAndValueMapping mapper;
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(fusedBlock);
-
-    // The block arguments are
-    // [index_0, index_1, ... ,
-    //   consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
-    //   producer_operand_0, ... , producer_operand_(n-1)],
-    //   consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
-    // , where n is the number of producer's operand and m is the number
-    // consumer's operand.
-    // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
-    // generic op. In this case, there are no indices in block arguments.
-    unsigned numProducerIndices =
-        isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
-    unsigned numConsumerIndices =
-        isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
-    // Firstly, add all the indices to the block arguments.
-    for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
-         i < e; ++i)
-      fusedBlock->addArgument(rewriter.getIndexType());
-    // Map the arguments for the unmodified args from the consumer.
-    for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
-      if (consumerArg.index() == consumerIdx + numConsumerIndices) {
-        // Map the arguments for the args from the producer.
-        for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
-          // If producer is an indexed_generic op, map the indices from consumer
-          // loop to producer loop (because the fusedOp is built based on
-          // consumer's perspective).
-          if (producerArg.index() < numProducerIndices) {
-            auto newIndex = rewriter.create<mlir::AffineApplyOp>(
-                producer.getLoc(),
-                consumerToProducerLoopsMap.getSubMap(producerArg.index()),
-                fusedBlock->getArguments().take_front(nloops));
-            mapper.map(producerArg.value(), newIndex);
-          } else {
-            mapper.map(producerArg.value(),
-                       fusedBlock->addArgument(producerArg.value().getType()));
-          }
+/// Generate the region of the fused tensor operation. The region of the fused
+/// op must be empty.
+static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
+                                        Operation *fusedOp, LinalgOp producer,
+                                        LinalgOp consumer,
+                                        AffineMap consumerToProducerLoopsMap,
+                                        unsigned consumerIdx, unsigned nloops) {
+  // Build the region of the fused op.
+  Block &producerBlock = producer.getOperation()->getRegion(0).front();
+  Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
+  Block *fusedBlock = new Block();
+  fusedOp->getRegion(0).push_back(fusedBlock);
+  BlockAndValueMapping mapper;
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(fusedBlock);
+
+  // The block arguments are
+  // [index_0, index_1, ... ,
+  //   consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
+  //   producer_operand_0, ... , producer_operand_(n-1)],
+  //   consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
+  // , where n is the number of producer's operand and m is the number
+  // consumer's operand.
+  // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
+  // generic op. In this case, there are no indices in block arguments.
+  unsigned numProducerIndices =
+      isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
+  unsigned numConsumerIndices =
+      isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
+  // Firstly, add all the indices to the block arguments.
+  for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
+       i < e; ++i)
+    fusedBlock->addArgument(rewriter.getIndexType());
+  // Map the arguments for the unmodified args from the consumer.
+  for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
+    if (consumerArg.index() == consumerIdx + numConsumerIndices) {
+      // Map the arguments for the args from the producer.
+      for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
+        // If producer is an indexed_generic op, map the indices from consumer
+        // loop to producer loop (because the fusedOp is built based on
+        // consumer's perspective).
+        if (producerArg.index() < numProducerIndices) {
+          auto newIndex = rewriter.create<mlir::AffineApplyOp>(
+              producer.getLoc(),
+              consumerToProducerLoopsMap.getSubMap(producerArg.index()),
+              fusedBlock->getArguments().take_front(nloops));
+          mapper.map(producerArg.value(), newIndex);
+        } else {
+          mapper.map(producerArg.value(),
+                     fusedBlock->addArgument(producerArg.value().getType()));
         }
-        continue;
       }
+      continue;
+    }
 
-      // If consumer is an indexed_generic op, map the indices to the block
-      // arguments directly. Otherwise, add the same type of arugment and map to
-      // it.
-      if (consumerArg.index() < numConsumerIndices) {
-        mapper.map(consumerArg.value(),
-                   fusedBlock->getArgument(consumerArg.index()));
-      } else {
-        mapper.map(consumerArg.value(),
-                   fusedBlock->addArgument(consumerArg.value().getType()));
-      }
+    // If consumer is an indexed_generic op, map the indices to the block
+    // arguments directly. Otherwise, add the same type of arugment and map to
+    // it.
+    if (consumerArg.index() < numConsumerIndices) {
+      mapper.map(consumerArg.value(),
+                 fusedBlock->getArgument(consumerArg.index()));
+    } else {
+      mapper.map(consumerArg.value(),
+                 fusedBlock->addArgument(consumerArg.value().getType()));
     }
+  }
 
-    // Add operations from producer (except the yield operation) to the fused
-    // op.
-    for (auto &op : producerBlock.getOperations()) {
-      if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
-        // Lookup the value the yield operation is mapped to.
-        Value yieldVal = yieldOp.getOperand(0);
-        if (Value clonedVal = mapper.lookupOrNull(yieldVal))
-          mapper.map(
-              consumerBlock.getArgument(consumerIdx + numConsumerIndices),
-              clonedVal);
-        continue;
-      }
-      rewriter.clone(op, mapper);
+  // Add operations from producer (except the yield operation) to the fused
+  // op.
+  for (auto &op : producerBlock.getOperations()) {
+    if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
+      // Lookup the value the yield operation is mapped to.
+      Value yieldVal = yieldOp.getOperand(0);
+      if (Value clonedVal = mapper.lookupOrNull(yieldVal))
+        mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
+                   clonedVal);
+      continue;
     }
-    for (auto &op : consumerBlock.getOperations())
-      rewriter.clone(op, mapper);
+    rewriter.clone(op, mapper);
   }
-};
-} // namespace
+  for (auto &op : consumerBlock.getOperations())
+    rewriter.clone(op, mapper);
+}
+
+static Optional<SmallVector<Value, 1>>
+fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
+                  PatternRewriter &rewriter,
+                  OperationFolder *folder = nullptr) {
+  if (!areTensorOpsFusable(producer, consumer, consumerIdx))
+    return llvm::None;
+
+  unsigned numFusedOperands =
+      producer.getNumInputs() + consumer.getNumInputs() - 1;
+
+  // Compute the fused operands list,
+  SmallVector<Value, 2> fusedOperands;
+  fusedOperands.reserve(numFusedOperands);
+  auto consumerOperands = consumer.getInputs();
+  auto producerOperands = producer.getInputs();
+  fusedOperands.assign(consumerOperands.begin(),
+                       std::next(consumerOperands.begin(), consumerIdx));
+  fusedOperands.append(producerOperands.begin(), producerOperands.end());
+  fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
+                       consumerOperands.end());
+
+  // Compute indexing_maps for the fused operation. The indexing_maps for the
+  // operands of the consumers that arent fused are the same. The
+  // indexing_maps for the producers need to be computed based on the
+  // indexing_map of the operand at consumerIdx in the consumer.
+  SmallVector<Attribute, 4> fusedIndexMaps;
+  auto consumerIndexMaps = consumer.indexing_maps();
+  fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs());
+  fusedIndexMaps.assign(consumerIndexMaps.begin(),
+                        std::next(consumerIndexMaps.begin(), consumerIdx));
+  // Compute indexing maps for the producer args in the fused operation.
+  getIndexingMapOfProducerOperandsInFusedOp(
+      producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
+
+  // Append the indexing maps for the remaining consumer operands.
+  fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
+                        consumerIndexMaps.end());
+
+  // Generate the fused op.
+  // Tensor-level fusion is only on ops without initTensors and outputBuffers.
+  LinalgOp fusedOp;
+  if (isa<GenericOp>(producer.getOperation()) &&
+      isa<GenericOp>(consumer.getOperation())) {
+    fusedOp = rewriter
+                  .create<GenericOp>(consumer.getLoc(),
+                                     consumer.getOperation()->getResultTypes(),
+                                     /*inputs=*/fusedOperands,
+                                     /*outputBuffers=*/ValueRange{},
+                                     /*initTensors=*/ValueRange{},
+                                     rewriter.getArrayAttr(fusedIndexMaps),
+                                     consumer.iterator_types(),
+                                     /*doc=*/nullptr,
+                                     /*library_call=*/nullptr,
+                                     /*symbol_source=*/nullptr)
+                  .getOperation();
+  } else {
+    fusedOp =
+        rewriter
+            .create<IndexedGenericOp>(consumer.getLoc(),
+                                      consumer.getOperation()->getResultTypes(),
+                                      /*inputs=*/fusedOperands,
+                                      /*outputBuffers=*/ValueRange{},
+                                      /*initTensors=*/ValueRange{},
+                                      rewriter.getArrayAttr(fusedIndexMaps),
+                                      consumer.iterator_types(),
+                                      /*doc=*/nullptr,
+                                      /*library_call=*/nullptr,
+                                      /*symbol_source=*/nullptr)
+            .getOperation();
+  }
+
+  // Construct an AffineMap from consumer loops to producer loops.
+  // consumer loop -> tensor index
+  AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
+  // producer loop -> tensor index
+  AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+  // tensor index -> producer loop
+  AffineMap invProducerResultIndexMap =
+      inversePermutation(producerResultIndexMap);
+  assert(invProducerResultIndexMap &&
+         "expected producer result indexig map to be invertible");
+  // consumer loop -> producer loop
+  AffineMap consumerToProducerLoopsMap =
+      invProducerResultIndexMap.compose(consumerResultIndexMap);
+
+  generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
+                              consumer, consumerToProducerLoopsMap, consumerIdx,
+                              consumer.getNumLoops());
+  return SmallVector<Value, 1>(fusedOp.getOperation()->getResults());
+}
 
 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
 /// provided, given the shape of the source tensor that corresponds to the
@@ -313,18 +306,21 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
 /// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is
 /// true) or its producer (if `asProducer` is false) given the indexing map at
 /// its use.
-static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
-                                     AffineMap useIndexMap, bool asProducer) {
+static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp,
+                                                     AffineMap useIndexMap,
+                                                     bool asProducer) {
   RankedTensorType returnType = reshapeOp.getResultType();
   RankedTensorType operandType = reshapeOp.getSrcType();
-  // Reshape is fusible with its consumer (i.e. reshape as a producer) when its
+  // Reshape is fusable with its consumer (i.e. reshape as a producer) when its
   // operand is of lesser rank than the result. Fusing when operand has higher
   // rank will require use of mods and divs in the indexing maps of the fused op
   // which would make it non-invertible. Similarly reshape is fused with its
   // producer (i.e. reshape as consumer) only if the return type has lesser
   // rank.
-  if ((asProducer && returnType.getRank() < operandType.getRank()) ||
-      (!asProducer && operandType.getRank() < returnType.getRank()))
+  if ((asProducer && reshapeOp.getSrcType().hasStaticShape() &&
+       returnType.getRank() < operandType.getRank()) ||
+      (!asProducer && reshapeOp.getResultType().hasStaticShape() &&
+       operandType.getRank() < returnType.getRank()))
     return false;
   return useIndexMap.isPermutation();
 }
@@ -346,314 +342,533 @@ static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
   return nullptr;
 }
 
-namespace {
+/// Conditions for folding a generic/indexed-generic operation with a reshape op
+/// by expanding the iteration space dimensionality for tensor operations. These
+/// are preconditions assumed by `fuseWithReshapeByExpansion` which implements
+/// the following fusion pattern.
+///
+///  Consider
+///
+///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
+///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
+///  %d = linalg.tensor_reshape %c
+///         [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+///          affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+///          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
+///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+///
+///  The reshape can be folded into the `linalgOp` if the
+///  generic/indexed-generic op loop dimensionality is increased to match the
+///  result (operand) of the tensor_reshape when the reshape is expanding
+///  (folding). The indexing_map of the fused tensor in the `linalgOp` and the
+///  reassociation map helps compute the indexing maps of the modified op. For
+///  the above example, based on the reassociation map it can be concluded that
+///
+///  - The loop used to access the first dimension of the fused tensor is split
+///    into two.
+///  - The loop used to access the second dimension of the fused tensor is kept
+///    as is.
+///  - The loop used to access the third dimension of the fused tensor is split
+///    into three.
+///
+///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
+///  modified op, then
+///
+///   d0 -> e0, e1
+///   d1 -> e2, e3, e4
+///   d2 -> e5
+///
+///  Substituting this, the generic op can be rewritten as
+///
+///  %d = linalg.generic ins(%0, %1 : ...)
+///        indexing_maps =
+///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
+///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
+///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
+///
+///  Since the operands of the linalg generic now have expanded ranks, reshapes
+///  can be introduced to make it consistent
+///
+///  %0 = linalg.tensor_reshape %a
+///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e2)>,
+///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e3, e4)>,
+///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e5)>]
+///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+///  %1 = linalg.tensor_reshape %b
+///         [affine_map<(e0, e1, e2, e3) -> (e0, e1, e2)>,
+///          affine_map<(e0, e1, e2, e3) -> (e3)>]
+///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+///
+///  The added reshapes are again expanding patterns, so they will get fused
+///  with their producers if possible.
+static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
+                                               unsigned fusedTensorIndex) {
+  // Is fusable only if:
+  // - The linalgOp is a generic op with tensor semantics.
+  // - All the indexing maps of linalgOp, for inputs as well as outputs, are
+  //   projected permutations (`fuseWithReshapeByExpansion` casts every
+  //   indexing map result to an AffineDimExpr).
+  // - The indexing map at the position representing the fused tensor is a
+  //   permutation.
+  // - All the loops in linalgOp are parallel loops.
+  return isa<GenericOp>(linalgOp.getOperation()) &&
+         linalgOp.hasTensorSemantics() &&
+         llvm::all_of(linalgOp.indexing_maps().getValue(),
+                      [](Attribute attr) {
+                        return attr.cast<AffineMapAttr>()
+                            .getValue()
+                            .isProjectedPermutation();
+                      }) &&
+         linalgOp.getIndexingMap(fusedTensorIndex).isPermutation() &&
+         llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
+           return attr.cast<StringAttr>().getValue() ==
+                  getParallelIteratorTypeName();
+         });
+}
 
-/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
-struct FuseTensorReshapeOpAsProducer {
-  static bool isFusible(TensorReshapeOp producer, LinalgOp consumer,
-                        unsigned consumerIdx) {
-    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
-           consumer.hasTensorSemantics() &&
-           isTensorReshapeOpFusible(producer,
-                                    consumer.getInputIndexingMap(consumerIdx),
-                                    /*asProducer=*/true);
+/// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
+/// op as explained in `isFusableWithReshapeByDimExpansion`. Assumes that those
+/// conditions have been satisfied.
+static Optional<SmallVector<Value, 1>>
+fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
+                           unsigned fusedTensorIndex, PatternRewriter &rewriter,
+                           OperationFolder *folder = nullptr) {
+  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
+         "preconditions for fuse operation failed");
+  // Check if reshape is expanding or collapsing.
+  bool isExpanding =
+      reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
+  RankedTensorType expandedType =
+      isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
+  RankedTensorType foldedType =
+      isExpanding ? reshapeOp.getSrcType() : reshapeOp.getResultType();
+  AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
+
+  // The reshape is folding/expanding consecutive dimensions. Given the indexing
+  // map of the fused tensor find the number of dimensions each of the loops of
+  // the original op is expanded into. Also record the shape of the expanded
+  // dimensions.
+  ArrayRef<int64_t> expandedShape = expandedType.getShape();
+  SmallVector<unsigned, 4> numFoldedDims(foldedType.getRank(), 0);
+  SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
+      foldedType.getRank());
+  auto reassociationMaps = reshapeOp.getReassociationMaps();
+  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
+    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
+    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
+    numFoldedDims[pos] = foldedDims.getNumResults();
+    ArrayRef<int64_t> shape = expandedShape.slice(
+        foldedDims.getResult(0).cast<AffineDimExpr>().getPosition(),
+        numFoldedDims[pos]);
+    expandedDimsShape[pos].assign(shape.begin(), shape.end());
   }
 
-  static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
-                       unsigned consumerIdx, PatternRewriter &rewriter,
-                       OperationFolder *folder = nullptr) {
-    if (producer.src().getDefiningOp<ConstantOp>())
-      return nullptr;
+  // remapping[i] is the index of the first loop of the expanded op that loop i
+  // of the original op maps to, i.e. the exclusive prefix sum of numFoldedDims.
+  SmallVector<unsigned, 4> remapping(numFoldedDims.size() + 1, 0);
+  unsigned sum = 0;
+  for (auto numFoldedDim : llvm::enumerate(numFoldedDims)) {
+    sum += numFoldedDim.value();
+    remapping[numFoldedDim.index() + 1] = sum;
+  }
 
-    if (!isFusible(producer, consumer, consumerIdx))
-      return nullptr;
+  SmallVector<AffineMap, 4> expandedOpIndexingMaps;
+  // Compute the modified indexing maps by replacing every loop (AffineDimExpr)
+  // in the original indexing map with the sequence of loops that it is expanded
+  // to.
+  for (AffineMap indexingMap : linalgOp.getIndexingMaps()) {
+    SmallVector<AffineExpr, 4> newExprs;
+    for (AffineExpr expr : indexingMap.getResults()) {
+      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+      for (unsigned newPos :
+           llvm::seq<unsigned>(remapping[pos], remapping[pos + 1])) {
+        newExprs.push_back(rewriter.getAffineDimExpr(newPos));
+      }
+    }
+    expandedOpIndexingMaps.push_back(
+        AffineMap::get(remapping.back(), indexingMap.getNumSymbols(), newExprs,
+                       rewriter.getContext()));
+  }
 
-    // Compute the fused operands list,
-    Operation *consumerOp = consumer.getOperation();
-    SmallVector<Value, 2> fusedOperands(consumerOp->getOperands());
-    fusedOperands[consumerIdx] = producer.src();
+  // The operands of the expanded op are computed by reshaping the original
+  // operands. The reshape depends on the ordering of the loop used to access
+  // the tensor in the original operation, and are expanded into as many
+  // dimensions as the loop is expanded into (as computed by `remapping`).
+  auto getReshapeInfo =
+      [&](AffineMap operandIndexingMap,
+          SmallVectorImpl<ReassociationIndices> &reassociation,
+          SmallVectorImpl<int64_t> &expandedOpOperandShape) {
+        unsigned reshapeDims = 0;
+        for (AffineExpr expr : operandIndexingMap.getResults()) {
+          unsigned origDim = expr.cast<AffineDimExpr>().getPosition();
+          auto foldedDims = llvm::seq<int64_t>(
+              reshapeDims, reshapeDims + numFoldedDims[origDim]);
+          reassociation.emplace_back(foldedDims.begin(), foldedDims.end());
+          expandedOpOperandShape.append(expandedDimsShape[origDim].begin(),
+                                        expandedDimsShape[origDim].end());
+          reshapeDims += numFoldedDims[origDim];
+        }
+      };
+  SmallVector<Value, 4> expandedOpOperands;
+  for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+    if (operand.index() == fusedTensorIndex) {
+      expandedOpOperands.push_back(reshapeOp.src());
+      continue;
+    }
+    AffineMap indexingMap = linalgOp.getIndexingMap(operand.index());
+    SmallVector<ReassociationIndices, 4> reassociation;
+    SmallVector<int64_t, 4> expandedOperandShape;
+    getReshapeInfo(indexingMap, reassociation, expandedOperandShape);
+    Type expandedOperandType = RankedTensorType::get(
+        expandedOperandShape,
+        operand.value().getType().cast<ShapedType>().getElementType());
+    if (expandedOperandType != operand.value().getType()) {
+      expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
+          linalgOp.getLoc(), expandedOperandType, operand.value(),
+          reassociation));
+    } else {
+      expandedOpOperands.push_back(operand.value());
+    }
+  }
+  SmallVector<Type, 1> resultTypes;
+  SmallVector<SmallVector<ReassociationIndices, 4>, 1> resultReassociation;
+  for (auto result : llvm::enumerate(linalgOp.getOperation()->getResults())) {
+    AffineMap indexingMap =
+        linalgOp.getIndexingMap(linalgOp.getNumInputs() + result.index());
+    SmallVector<ReassociationIndices, 4> reassociation;
+    SmallVector<int64_t, 4> expandedResultShape;
+    getReshapeInfo(indexingMap, reassociation, expandedResultShape);
+    resultTypes.push_back(RankedTensorType::get(
+        expandedResultShape,
+        result.value().getType().cast<ShapedType>().getElementType()));
+    resultReassociation.emplace_back(std::move(reassociation));
+  }
 
-    // Compute indexing_maps for the fused operation. The indexing_maps for the
-    // operands of the consumers that arent fused are the same.
-    SmallVector<AffineMap, 4> fusedIndexMaps =
-        llvm::to_vector<4>(llvm::map_range(
-            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
-              return attr.cast<AffineMapAttr>().getValue();
-            }));
+  // The iterator types of the expanded op are all parallel.
+  SmallVector<StringRef, 4> iteratorTypes(remapping.back(),
+                                          getParallelIteratorTypeName());
+
+  LinalgOp fusedOp = createLinalgOpOfSameType(
+      linalgOp, rewriter, linalgOp.getLoc(), resultTypes,
+      /*inputs=*/expandedOpOperands,
+      /*outputBuffers=*/ValueRange{},
+      /*initTensors=*/ValueRange{}, expandedOpIndexingMaps, iteratorTypes);
+  Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
+  // TODO: Add support for indexed generic op, which would need mapping the
+  // expanded dimensions to the original dimension arguments.
+  rewriter.cloneRegionBefore(linalgOp.getOperation()->getRegion(0), fusedRegion,
+                             fusedRegion.begin());
+
+  // Reshape the result values to their original shape if this is a collapsing
+  // reshape folded into its consumer.
+  SmallVector<Value, 1> resultVals;
+  for (auto result : llvm::enumerate(linalgOp.getOperation()->getResults())) {
+    if (!isExpanding &&
+        resultTypes[result.index()] != result.value().getType()) {
+      resultVals.push_back(rewriter.create<TensorReshapeOp>(
+          linalgOp.getLoc(), result.value().getType(),
+          fusedOp.getOperation()->getResult(result.index()),
+          resultReassociation[result.index()]));
+    } else {
+      resultVals.push_back(fusedOp.getOperation()->getResult(result.index()));
+    }
+  }
+  // One replacement value per result of the original op.
+  return resultVals;
+}
 
-    // Accepted consumer maps are either identity or permutation.
-    auto invMap = inversePermutation(fusedIndexMaps[consumerIdx]);
+namespace {
 
-    // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap =
-        linearizeCollapsedDims(invMap, producer.getResultType().getShape(),
-                               producer.getReassociationMaps());
-    for (AffineExpr expr : modifiedMap.getResults()) {
-      if (!expr.isPureAffine())
-        return nullptr;
-    }
-    fusedIndexMaps[consumerIdx] = modifiedMap;
+/// Pattern to fold tensor_reshape op with its consumer by using the source of
+/// the reshape op as the operand in the consumer (instead of the result of the
+/// tensor_reshape op) when the tensor_reshape op is expanding. The
+/// corresponding index map in the consumer needs to be modified to linearize
+/// the folded dimension.
+///
+/// For example,
+///
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %0 = linalg.tensor_reshape %arg0
+///        [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k)>,
+///         affine_map<(i, j, k, l) -> (l)>]
+///      : tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
+/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
+///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
+///        -> tensor<?x?x4x?xf32>
+///
+/// can be folded into
+///
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
+///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
+///        -> tensor<?x?x4x?xf32>
+template <typename LinalgOpTy>
+struct FoldProducerReshapeOpByLinearization
+    : public OpRewritePattern<LinalgOpTy> {
+  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
 
-    // Further check that the resulting index maps can be fused and
-    // inverted. Without this the resultant op is not legal.
-    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
-      return nullptr;
+  LogicalResult matchAndRewrite(LinalgOpTy op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.hasTensorSemantics())
+      return failure();
+    LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
+    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+      TensorReshapeOp reshapeOp =
+          operand.value().getDefiningOp<TensorReshapeOp>();
+      if (!reshapeOp ||
+          !isTensorReshapeOpFoldableByLinearization(
+              reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
+              /*asProducer =*/true))
+        continue;
 
-    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
-        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
-          return AffineMapAttr::get(map);
-        }));
-    LinalgOp fusedOp = createLinalgOpOfSameType(
-        consumer, rewriter, rewriter.getUnknownLoc(),
-        consumerOp->getResultTypes(),
-        /*inputs=*/fusedOperands,
-        /*outputBuffers=*/ValueRange{},
-        /*initTensors=*/ValueRange{}, // no init tensors for now.
-        rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr,
-        /*symbol_source=*/nullptr);
-    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
-    rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion,
-                               fusedRegion.begin());
-    return fusedOp;
+      // Compute the fused operands list (all operands, not just the inputs).
+      SmallVector<Value, 2> fusedOperands(op.getOperation()->getOperands());
+      fusedOperands[operand.index()] = reshapeOp.src();
+
+      // Compute indexing_maps for the fused operation. The indexing_maps for
+      // the operands of the consumer that aren't fused are unchanged.
+      SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
+          op.indexing_maps().template getAsValueRange<AffineMapAttr>());
+
+      // Accepted consumer maps are either identity or permutation.
+      auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
+
+      // Compute the indexing map to use for the source of the reshape op.
+      AffineMap modifiedMap =
+          linearizeCollapsedDims(invMap, reshapeOp.getResultType().getShape(),
+                                 reshapeOp.getReassociationMaps());
+      // Try the next operand if the linearized map is not purely affine.
+      if (llvm::any_of(modifiedMap.getResults(),
+                       [](AffineExpr expr) { return !expr.isPureAffine(); }))
+        continue;
+      fusedIndexMaps[operand.index()] = modifiedMap;
+
+      // Further check that the resulting index maps can be fused and
+      // inverted. Without this the resultant op is not legal.
+      if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
+        continue;
+
+      rewriter.startRootUpdate(op);
+      op.getOperation()->setOperands(fusedOperands);
+      op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps));
+      rewriter.finalizeRootUpdate(op);
+      if (reshapeOp.use_empty())
+        rewriter.eraseOp(reshapeOp);
+      return success();
+    }
+    return failure();
   }
 };
 
-/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
-struct FuseTensorReshapeOpAsConsumer {
-  static bool isCollapsingAndFusible(LinalgOp producer,
-                                     TensorReshapeOp consumer,
-                                     unsigned consumerIdx) {
-    return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
-           producer.hasTensorSemantics() &&
-           isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
-                                    /*asProducer=*/false);
+/// Pattern to fuse a collapsing tensor_reshape op with the generic op that
+/// consumes it. The loop dimensionality of the consumer generic op is expanded
+/// (see `fuseWithReshapeByExpansion`); the first fusable input is used.
+struct FoldWithProducerReshapeOpByExpansion
+    : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
+    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+      TensorReshapeOp reshapeOp =
+          operand.value().getDefiningOp<TensorReshapeOp>();
+      if (!reshapeOp)
+        continue;
+
+      // Fold only if
+      // - The tensor reshape op is folding (source rank >= result rank).
+      // - All constraints of fusing with reshape by expansion are met.
+      if (reshapeOp.getSrcType().getRank() <
+              reshapeOp.getResultType().getRank() ||
+          !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()))
+        continue;
+
+      Optional<SmallVector<Value, 1>> replacementValues =
+          fuseWithReshapeByExpansion(linalgOp, reshapeOp, operand.index(),
+                                     rewriter);
+      if (!replacementValues)
+        return failure();
+      rewriter.replaceOp(genericOp, replacementValues.getValue());
+      if (reshapeOp.use_empty())
+        rewriter.eraseOp(reshapeOp);
+      return success();
+    }
+    return failure();
   }
+};
 
-  static LinalgOp fuseCollapsingCase(LinalgOp producer,
-                                     TensorReshapeOp consumer,
-                                     unsigned consumerIdx,
-                                     PatternRewriter &rewriter) {
+/// Pattern to fold a tensor_reshape op with its producer generic op. The
+/// producer's output indexing map is linearized over the folded dimensions.
+struct FoldConsumerReshapeOpByLinearization
+    : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
+    if (!producer ||
+        !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
+        !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
+        !isTensorReshapeOpFoldableByLinearization(
+            reshapeOp, producer.getOutputIndexingMap(0), /*asProducer =*/false))
+      return failure();
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
-    SmallVector<AffineMap, 4> fusedIndexMaps =
-        llvm::to_vector<4>(llvm::map_range(
-            producer.indexing_maps(), [](Attribute attr) -> AffineMap {
-              return attr.cast<AffineMapAttr>().getValue();
-            }));
+    SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
+        producer.indexing_maps().getAsValueRange<AffineMapAttr>());
 
     auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
 
     // Compute the indexing map to use for the operand of the producer.
     AffineMap modifiedMap =
-        linearizeCollapsedDims(invMap, consumer.getSrcType().getShape(),
-                               consumer.getReassociationMaps());
+        linearizeCollapsedDims(invMap, reshapeOp.getSrcType().getShape(),
+                               reshapeOp.getReassociationMaps());
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine())
-        return nullptr;
+        return reshapeOp.emitRemark("fused op indexing map is not affine");
     }
     fusedIndexMaps.back() = modifiedMap;
 
     // Further check that the resulting index maps can be fused and
     // inverted. Without this the resultant op is not legal.
     if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
-      return nullptr;
+      return reshapeOp.emitRemark("fused op loop bound computation failed");
 
-    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
-        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
-          return AffineMapAttr::get(map);
-        }));
-
-    Operation *producerOp = producer.getOperation();
     LinalgOp fusedOp = createLinalgOpOfSameType(
-        producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
-        /*inputs=*/producerOp->getOperands(),
+        producer, rewriter, producer.getLoc(), reshapeOp.getResultType(),
+        /*inputs=*/producer.getInputs(),
         /*outputBuffers=*/ValueRange{},
         /*initTensors=*/ValueRange{}, // no init tensors for now.
-        rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(),
+        rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+        producer.iterator_types(),
         /*doc=*/nullptr,
         /*library_call=*/nullptr,
         /*symbol_source=*/nullptr);
     auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
-    rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion,
-                               fusedRegion.begin());
-    return fusedOp;
-  }
-
-  static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer,
-                                    unsigned consumerIdx) {
-    // Is fusible only if:
-    //   1) The producer is a generic op.
-    //   2) The producer has tensor semantics.
-    //   3) The tensor reshape op is a expanding case.
-    //   4) All the shapes are the same for the generic op.
-    //   5) All the indexing maps in producer are identity.
-    //   6) All the loops in producer are parallel loops.
-    //   7) The producer has a single user.
-    auto types = producer.getInputOutputShapedTypes();
-    assert(!types.empty());
-    return isa<GenericOp>(producer.getOperation()) &&
-           producer.hasTensorSemantics() &&
-           consumer.getSrcType().getRank() <
-               consumer.getResultType().getRank() &&
-           std::equal(types.begin() + 1, types.end(), types.begin()) &&
-           llvm::all_of(producer.getIndexingMaps(),
-                        [](AffineMap map) { return map.isIdentity(); }) &&
-           llvm::all_of(producer.iterator_types(),
-                        [](Attribute attr) {
-                          return attr.cast<StringAttr>().getValue() ==
-                                 getParallelIteratorTypeName();
-                        }) &&
-           producer.getOperation()->hasOneUse();
-  }
-
-  static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer,
-                                    unsigned consumerIdx,
-                                    PatternRewriter &rewriter) {
-    Location loc = producer.getLoc();
-    auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
-    SmallVector<Value, 4> args;
-    for (auto arg : producer.getOperation()->getOperands()) {
-      auto type = RankedTensorType::get(
-          dstShape, arg.getType().cast<ShapedType>().getElementType());
-      args.push_back(rewriter.createOrFold<linalg::TensorReshapeOp>(
-          loc, type, arg, consumer.reassociation()));
-    }
-
-    SmallVector<Type, 4> resultTypes;
-    for (auto t : producer.getOutputTensorTypes()) {
-      Type type = RankedTensorType::get(dstShape,
-                                        t.cast<ShapedType>().getElementType());
-      resultTypes.push_back(type);
-    }
-
-    int rank = dstShape.size();
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, resultTypes, /*inputs=*/args,
-        /*outputBuffers=*/ValueRange{},
-        /*initTensors=*/ValueRange{},
-        SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
-                                  rewriter.getMultiDimIdentityMap(rank)),
-        SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
-    Region &region = genericOp.getRegion();
-    rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
-                               region.begin());
-    return cast<LinalgOp>(genericOp.getOperation());
-  }
-
-  static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
-                       unsigned consumerIdx, PatternRewriter &rewriter,
-                       OperationFolder *folder = nullptr) {
-    if (isCollapsingAndFusible(producer, consumer, consumerIdx))
-      return fuseCollapsingCase(producer, consumer, consumerIdx, rewriter);
-    if (isExpandingAndFusible(producer, consumer, consumerIdx))
-      return fuseExpandingCase(producer, consumer, consumerIdx, rewriter);
-    return nullptr;
+    rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0),
+                               fusedRegion, fusedRegion.begin());
+    rewriter.replaceOp(reshapeOp, fusedOp.getOperation()->getResults());
+    if (producer.use_empty())
+      rewriter.eraseOp(producer);
+    return success();
   }
 };
 
-/// Implementation of fusion on tensor ops when producer is a splat constant.
-struct FuseConstantOpAsProducer {
-  static bool isFusible(ConstantOp producer, LinalgOp consumer,
-                        unsigned consumerIdx) {
-    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
-           consumer.hasTensorSemantics() &&
-           producer.getResult().getType().isa<RankedTensorType>() &&
-           producer.value().cast<DenseElementsAttr>().isSplat();
+/// Pattern to fold a tensor_reshape op with its producer generic op if the
+/// tensor_reshape op is expanding, by expanding the dimensionality of the loop
+/// in the producer op. Only single-result producers are handled.
+struct FoldReshapeWithGenericOpByExpansion
+    : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    // Fold only if
+    // - The tensor reshape op is an expanding case.
+    // - All constraints of fusing with reshape by expansion are met.
+    if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank())
+      return failure();
+    LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
+    if (!producer || producer.getNumOutputs() != 1 ||
+        !isFusableWithReshapeByDimExpansion(producer, producer.getNumInputs()))
+      return failure();
+    Optional<SmallVector<Value, 1>> replacementValues =
+        fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
+                                   rewriter);
+    if (!replacementValues)
+      return failure();
+    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
+    if (producer.use_empty())
+      rewriter.eraseOp(producer);
+    return success();
   }
+};
 
-  static LinalgOp fuse(ConstantOp producer, LinalgOp consumer,
-                       unsigned consumerIdx, PatternRewriter &rewriter,
-                       OperationFolder *folder = nullptr) {
-    if (!isFusible(producer, consumer, consumerIdx))
-      return nullptr;
-
-    // The indexing_maps for the operands of the fused operation are same as
-    // those for the operands of the consumer without the indexing map at
-    // consumerIdx
-    SmallVector<AffineMap, 4> fusedIndexMaps =
-        llvm::to_vector<4>(llvm::map_range(
-            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
-              return attr.cast<AffineMapAttr>().getValue();
-            }));
-    fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));
-
-    // The operands list is same as the consumer with the argument for constant
-    // index dropped.
-    Operation *consumerOp = consumer.getOperation();
-    SmallVector<Value, 4> fusedOperands(consumerOp->getOperands());
-    fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
-
-    // Create a constant scalar value from the splat constant.
-    Value scalarConstant = rewriter.create<ConstantOp>(
-        producer.getLoc(),
-        producer.value().cast<DenseElementsAttr>().getSplatValue());
+/// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
+template <typename LinalgOpTy>
+struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
+  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
 
-    LinalgOp fusedOp = createLinalgOpOfSameType(
-        consumer, rewriter, rewriter.getUnknownLoc(),
-        consumerOp->getResultTypes(),
-        /*inputs=*/fusedOperands,
-        /*outputBuffers=*/ValueRange{},
-        /*initTensors=*/ValueRange{}, // no init tensors for now.
-        rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-        consumer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr,
-        /*symbol_source=*/nullptr);
+  LogicalResult matchAndRewrite(LinalgOpTy op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.hasTensorSemantics())
+      return failure();
+    LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
+    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+      ConstantOp constantOp = operand.value().getDefiningOp<ConstantOp>();
+      if (!constantOp || !constantOp.value().isa<DenseElementsAttr>() ||
+          !constantOp.value().cast<DenseElementsAttr>().isSplat())
+        continue;
 
-    // Map the block argument corresponding to the replaced argument with the
-    // scalar constant.
-    Region &consumerRegion = consumerOp->getRegion(0);
-    Block &entryBlock = *consumerRegion.begin();
-    unsigned argIndex = entryBlock.getNumArguments() -
-                        consumerOp->getNumOperands() + consumerIdx;
-    BlockAndValueMapping mapping;
-    mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
-    Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
-    rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
-                               mapping);
-    return fusedOp;
+      // The fused op keeps the indexing maps of linalgOp, except the one of
+      // the splat operand being folded (at position operand.index()).
+      SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
+          linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
+      fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
+
+      // The operands list is same as the linalgOp with the argument for
+      // constant index dropped.
+      SmallVector<Value, 4> fusedOperands(linalgOp.getInputs());
+      fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
+
+      // Create a constant scalar value from the splat constant.
+      Value scalarConstant = rewriter.create<ConstantOp>(
+          constantOp.getLoc(),
+          constantOp.value().cast<DenseElementsAttr>().getSplatValue());
+
+      LinalgOp fusedOp = createLinalgOpOfSameType(
+          linalgOp, rewriter, op.getLoc(),
+          linalgOp.getOperation()->getResultTypes(),
+          /*inputs=*/fusedOperands,
+          /*outputBuffers=*/ValueRange{},
+          /*initTensors=*/ValueRange{}, // no init tensors for now.
+          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+          linalgOp.iterator_types(),
+          /*doc=*/nullptr,
+          /*library_call=*/nullptr,
+          /*symbol_source=*/nullptr);
+
+      // Map the block argument corresponding to the replaced argument with the
+      // scalar constant. NOTE(review): this index math assumes the inputs are
+      // the trailing block arguments, i.e. that the op has no init tensors.
+      Region &linalgOpRegion = linalgOp.getOperation()->getRegion(0);
+      Block &entryBlock = *linalgOpRegion.begin();
+      unsigned argIndex = entryBlock.getNumArguments() -
+                          linalgOp.getNumInputs() + operand.index();
+      BlockAndValueMapping mapping;
+      mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
+      Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
+      rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion,
+                                 fusedRegion.begin(), mapping);
+      rewriter.replaceOp(linalgOp, fusedOp.getOperation()->getResults());
+      if (constantOp.use_empty())
+        rewriter.eraseOp(constantOp);
+      return success();
+    }
+    return failure();
   }
 };
 } // namespace
 
-Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
-                                       Operation *consumer,
-                                       unsigned consumerIdx,
-                                       OperationFolder *folder) {
+Optional<SmallVector<Value, 1>> // Values replacing `consumer`, or None.
+mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
+                            unsigned consumerIdx, OperationFolder *folder) {
   if (consumerIdx >= consumer->getNumOperands())
-    return nullptr;
+    return llvm::None; // No operand at consumerIdx.
   Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
   if (!producer || producer->getNumResults() != 1)
-    return nullptr;
+    return llvm::None; // Block argument or multi-result producer.
 
   // Fuse when consumer is GenericOp or IndexedGenericOp.
-  if (isa<GenericOp, IndexedGenericOp>(consumer)) {
-    if (isa<GenericOp, IndexedGenericOp>(producer))
-      return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
-                                           cast<LinalgOp>(consumer),
-                                           consumerIdx, rewriter, folder);
-    if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer))
-      return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer,
-                                                 cast<LinalgOp>(consumer),
-                                                 consumerIdx, rewriter, folder);
-    if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
-      return FuseConstantOpAsProducer::fuse(constantOpProducer,
-                                            cast<LinalgOp>(consumer),
-                                            consumerIdx, rewriter, folder);
-    return nullptr;
-  }
+  if (!isa<GenericOp, IndexedGenericOp>(consumer) ||
+      !isa<GenericOp, IndexedGenericOp>(producer))
+    return llvm::None; // Both ops must be generic/indexed_generic.
 
-  if (isa<GenericOp, IndexedGenericOp>(producer)) {
-    // Fuse when consumer is a TensorReshapeOp.
-    if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
-      return FuseTensorReshapeOpAsConsumer::fuse(
-          cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
-    }
-  }
-
-  return nullptr;
+  return fuseTensorOpsImpl(cast<LinalgOp>(producer), cast<LinalgOp>(consumer),
+                           consumerIdx, rewriter, folder);
 }
 
 namespace {
@@ -669,10 +884,13 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
          llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
       Operation *producer =
           op.getOperation()->getOperand(operandNum).getDefiningOp();
-      if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
-        rewriter.replaceOp(op, fusedOp->getResults());
-        if (producer && llvm::all_of(producer->getResults(),
-                                     [](Value val) { return val.use_empty(); }))
+      if (!producer)
+        continue;
+      Optional<SmallVector<Value, 1>> fusedOpResults =
+          fuseTensorOps(rewriter, op, operandNum);
+      if (fusedOpResults) {
+        rewriter.replaceOp(op, *fusedOpResults);
+        if (producer->use_empty())
           rewriter.eraseOp(producer);
         return success();
       }
@@ -689,16 +907,62 @@ struct FusionOfTensorOpsPass
     Operation *op = getOperation();
     populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
     applyPatternsAndFoldGreedily(op->getRegions(), patterns);
-  };
+  }
+};
+
+/// Pass to test folding of reshape op with generic/indexed_generic ops by
+/// linearization.
+struct FoldReshapeOpsByLinearizationPass
+    : public LinalgFoldReshapeOpsByLinearizationBase<
+          FoldReshapeOpsByLinearizationPass> {
+  void runOnOperation() override {
+    OwningRewritePatternList patterns;
+    Operation *op = getOperation();
+    populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns);
+    applyPatternsAndFoldGreedily(op->getRegions(), patterns);
+  }
 };
+
 } // namespace
 
+/// Adds patterns folding linalg.tensor_reshape ops into adjacent
+/// generic/indexed_generic ops by linearizing their indexing maps.
+void mlir::populateFoldReshapeOpsByLinearizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp>,
+                  FoldProducerReshapeOpByLinearization<IndexedGenericOp>,
+                  FoldConsumerReshapeOpByLinearization>(context);
+}
+
+/// Adds patterns folding linalg.tensor_reshape ops with adjacent generic ops
+/// by expanding the dimensionality of the generic op.
+void mlir::populateFoldReshapeOpsByExpansionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FoldReshapeWithGenericOpByExpansion,
+                  FoldWithProducerReshapeOpByExpansion>(context);
+}
+
+/// Adds the linalg tensor-op fusion patterns: generic/indexed_generic
+/// producer-consumer fusion, splat-constant folding, reshape folding by
+/// expansion, and the canonicalization patterns of the ops involved.
 void mlir::populateLinalgTensorOpsFusionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
   patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
-                  FuseTensorOps<TensorReshapeOp>>(context);
+                  FoldSplatConstants<GenericOp>,
+                  FoldSplatConstants<IndexedGenericOp>>(context);
+  populateFoldReshapeOpsByExpansionPatterns(context, patterns);
+  GenericOp::getCanonicalizationPatterns(patterns, context);
+  IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
+  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
 }
 
+/// Creates the pass that greedily applies the tensor-op fusion patterns.
 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
   return std::make_unique<FusionOfTensorOpsPass>();
 }
+
+/// Creates the pass that greedily applies the reshape-by-linearization
+/// folding patterns.
+std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
+  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
+}

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 3f8b0680d7a4..40ef68d870ea 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -142,124 +142,6 @@ func @add_mul_scalar_fusion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tenso
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
-                                         %arg1 : tensor<?x?x4x?xf32>) ->
-                                         tensor<?x?x4x?xf32>
-{
-  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
-                                    affine_map<(i, j, k, l) -> (j, k)>,
-                                    affine_map<(i, j, k, l) -> (l)>] :
-    tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
-  %1 = linalg.generic {
-     indexing_maps = [#map0, #map0, #map0],
-     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-      ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
-      %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
-  } -> tensor<?x?x4x?xf32>
-  return %1 : tensor<?x?x4x?xf32>
-}
-
-// CHECK-LABEL: func @generic_op_reshape_producer_fusion
-//       CHECK: linalg.generic
-//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]]
-//   CHECK-NOT: linalg.generic
-
-
-// -----
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
-                                         %arg1 : tensor<?x?x4x5xf32>) ->
-                                         tensor<?x?xf32>
-{
-  %0 = linalg.generic {
-     indexing_maps = [#map0, #map0, #map0],
-     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-      ins(%arg0, %arg1 : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
-      %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
-  } -> tensor<?x?x4x5xf32>
-  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
-                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
-    tensor<?x?x4x5xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: func @generic_op_reshape_consumer_fusion
-//       CHECK: linalg.generic
-//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]]
-//   CHECK-NOT: linalg.generic
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
-                                           %arg1 : tensor<?x?x?x5xf32>) ->
-                                           tensor<?x?xf32>
-{
-  %0 = linalg.generic {
-     indexing_maps = [#map0, #map0, #map0],
-     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-      ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
-      %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
-  } -> tensor<?x?x?x5xf32>
-  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
-                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
-    tensor<?x?x?x5xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
-//       CHECK: linalg.tensor_reshape
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d2)>
-
-func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
-                                            -> tensor<8x33x4xf32> {
-  %cst = constant dense<2.000000e+00> : tensor<264x4xf32>
-  %0 = linalg.generic {
-     indexing_maps = [#map0, #map0, #map0],
-     iterator_types = ["parallel", "parallel"]}
-      ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>) {
-    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
-      %2 = mulf %arg1, %arg2 : f32
-      linalg.yield %2 : f32
-    } -> tensor<264x4xf32>
-  %1 = linalg.tensor_reshape %0 [#map1, #map2] :
-    tensor<264x4xf32> into tensor<8x33x4xf32>
-  return %1 : tensor<8x33x4xf32>
-}
-
-// The reshape op in `%arg0` is folded into the indexing map of generic op.
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0 * 33 + d1, d2)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-//       CHECK: func @generic_op_reshape_consumer_expanding
-//   CHECK-NOT:   linalg.tensor_reshape
-//       CHECK:   %[[CST:.*]] = constant {{.*}} : f32
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
-//  CHECK-SAME:   tensor<264x4xf32>
-//       CHECK:   -> tensor<8x33x4xf32>
-//   CHECK-NOT:   linalg.tensor_reshape
-
-// -----
-
 #map0 = affine_map<(d0, d1, d2) -> (d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
@@ -499,159 +381,3 @@ func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) {
 //      CHECK:   %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
 //      CHECK:   linalg.yield %[[VAL4]] : i32
 //   CHECK-NOT: linalg.indexed_generic
-
-// -----
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
-  -> tensor<?x?x4x?xi32> {
-  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
-                                    affine_map<(i, j, k, l) -> (j, k)>,
-                                    affine_map<(i, j, k, l) -> (l)>] :
-    tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
-  %1 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
-    ins(%0 : tensor<?x?x4x?xi32>) {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):       // no predecessors
-    %2 = index_cast %arg2 : index to i32
-    %3 = addi %arg6, %2 : i32
-    linalg.yield %3 : i32
-  } -> tensor<?x?x4x?xi32>
-  return %1 : tensor<?x?x4x?xi32>
-}
-
-// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion
-//   CHECK-NOT: linalg.tensor_reshape
-//       CHECK: linalg.indexed_generic
-//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-//   CHECK-NOT: linalg.tensor_reshape
-
-// -----
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
-  -> tensor<?x?xi32> {
-  %0 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
-    ins(%arg0 : tensor<?x?x4x5xi32>) {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):       // no predecessors
-    %2 = index_cast %arg2 : index to i32
-    %3 = addi %arg6, %2 : i32
-    linalg.yield %3 : i32
-  } -> tensor<?x?x4x5xi32>
-  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
-                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
-    tensor<?x?x4x5xi32> into tensor<?x?xi32>
-  return %1 : tensor<?x?xi32>
-}
-
-// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
-//   CHECK-NOT: linalg.tensor_reshape
-//       CHECK: linalg.indexed_generic
-//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-//   CHECK-NOT: linalg.tensor_reshape
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-#map0 = affine_map<(d0, d1, d2) -> (d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
-  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
-  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
-    ^bb0(%arg2: f32):  // no predecessors
-      linalg.yield %arg2 : f32
-    } -> tensor<3x7x5xf32>
-    return %1 : tensor<3x7x5xf32>
-}
-
-// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
-//   CHECK-NOT: linalg.tensor_reshape
-//       CHECK: linalg.generic
-//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-//   CHECK-NOT: linalg.tensor_reshape
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-#map0 = affine_map<(d0, d1, d2) -> (d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
-  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
-  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
-    ^bb0(%arg2: f32):  // no predecessors
-      linalg.yield %arg2 : f32
-    } -> tensor<5x7x3xf32>
-    return %1 : tensor<5x7x3xf32>
-}
-
-// CHECK-LABEL: func @generic_op_120_permultation_reshape_producer_fusion
-//   CHECK-NOT: linalg.tensor_reshape
-//       CHECK: linalg.generic
-//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-//   CHECK-NOT: linalg.tensor_reshape
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-#map0 = affine_map<(d0, d1, d2) -> (d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
-  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
-  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
-    ^bb0(%arg2: f32):  // no predecessors
-      linalg.yield %arg2 : f32
-    } -> tensor<5x3x7xf32>
-    return %1 : tensor<5x3x7xf32>
-}
-
-// CHECK-LABEL: func @generic_op_102_permultation_reshape_producer_fusion
-//   CHECK-NOT: linalg.tensor_reshape
-//       CHECK: linalg.generic
-//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-//   CHECK-NOT: linalg.tensor_reshape
-
-// -----
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
-
-
-#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0)>
-#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
-func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
-  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) {
-    ^bb0(%arg2: f32):  // no predecessors
-      linalg.yield %arg2 : f32
-  } -> tensor<5x3x7xf32>
-  %1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32>
-  return %1 : tensor<5x21xf32>
-}
-
-// CHECK-LABEL: func @generic_op_102_permultation_reshape_consumer_fusion
-//   CHECK-NOT: linalg.tensor_reshape
-//       CHECK: linalg.generic
-//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-//   CHECK-NOT: linalg.tensor_reshape

diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
new file mode 100644
index 000000000000..865b10b51696
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?x?xf32>,
+                                         %arg1 : tensor<?x?x?xf32>) ->
+                                         tensor<?x?x?xf32>
+{ // Collapsing reshape as producer: the generic op is expanded to 4-D.
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %1 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map1],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)>
+//      CHECK: func @generic_op_reshape_producer_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:     tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+//      CHECK:   %[[T1:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP4]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T0]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:     tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+//      CHECK:   return %[[T2]]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
+                                         %arg1 : tensor<?x?xf32>) ->
+                                         tensor<?x?x4x5xf32>
+{ // Expanding reshape as consumer: the generic op is expanded to 4-D.
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//      CHECK: func @generic_op_reshape_consumer_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T2:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
+//      CHECK:   return %[[T2]] : tensor<?x?x4x5xf32>
+
+
+// -----
+
+func @reshape_as_consumer_permutation
+  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>)
+    -> tensor<?x?x?x?x?x?xf32> { // Expanding reshape of a permuted generic op.
+  %c = linalg.generic {
+         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+         ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>) {
+       ^bb0(%arg0 : f32, %arg1: f32):
+         %1 = addf %arg0, %arg1 : f32
+         linalg.yield %1 : f32
+       } -> tensor<?x?x?xf32>
+  %d = linalg.tensor_reshape %c
+         [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
+       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+  return %d : tensor<?x?x?x?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+//  CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+//  CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+//      CHECK: func @reshape_as_consumer_permutation
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:     tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+//      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME:     [#[[MAP3]], #[[MAP4]]]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x?x?xf32>
+//      CHECK:   %[[T2:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+//      CHECK:   return %[[T2]] : tensor<?x?x?x?x?x?xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d2)>
+
+func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
+                                            -> tensor<8x33x4xf32> { // Static shapes; the splat %cst is folded away.
+  %cst = constant dense<2.000000e+00> : tensor<264x4xf32>
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+      %2 = mulf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    } -> tensor<264x4xf32>
+  %1 = linalg.tensor_reshape %0 [#map1, #map2] :
+    tensor<264x4xf32> into tensor<8x33x4xf32>
+  return %1 : tensor<8x33x4xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//      CHECK: func @generic_op_reshape_consumer_static
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
+//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     tensor<264x4xf32> into tensor<8x33x4xf32>
+//      CHECK:   %[[T1:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[T0]] : tensor<8x33x4xf32>)
+//      CHECK:   return %[[T1]] : tensor<8x33x4xf32>
+
+// -----
+
+func @scalar_reshape(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>)
+                     -> tensor<1x10xf32> { // Rank-0 input reshape, expanding result reshape.
+  %0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor<f32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+     iterator_types = ["parallel"]} ins(%0 : tensor<f32>) {
+  ^bb0(%arg2: f32):  // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<10xf32>
+  %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1) -> (d0, d1)>]
+    : tensor<10xf32> into tensor<1x10xf32>
+  return %2 : tensor<1x10xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//      CHECK: func @scalar_reshape
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32>
+//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] []
+// CHECK-SAME:     tensor<1xf32> into tensor<f32>
+//      CHECK:   %[[T1:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:     ins(%[[T0]] : tensor<f32>)
+//      CHECK:   return %[[T1]] : tensor<1x10xf32>

diff  --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
new file mode 100644
index 000000000000..468ae80a1288
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -0,0 +1,241 @@
+// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
+
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// Expanding producer reshape ?x?x? -> ?x?x4x? (source dim 1 split into (j, k)) feeding an elementwise mulf generic.
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
+                                         %arg1 : tensor<?x?x4x?xf32>) ->
+                                         tensor<?x?x4x?xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
+  %1 = linalg.generic {
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?x4x?xf32>
+  return %1 : tensor<?x?x4x?xf32>
+}
+// Expected: a single generic whose first operand is indexed through the rank-3 map (d0, d1 * 4 + d2, d3).
+// CHECK-LABEL: func @generic_op_reshape_producer_fusion
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.generic
+
+
+// -----
+// Collapsing consumer reshape ?x?x4x5 -> ?x? (dims j, k, l folded) applied to the result of an elementwise mulf generic.
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
+                                         %arg1 : tensor<?x?x4x5xf32>) ->
+                                         tensor<?x?xf32>
+{
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?x4x5xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x4x5xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// Expected: a single generic whose result map linearizes (d1, d2, d3) with strides 20 (= 4 * 5), 5 and 1.
+// CHECK-LABEL: func @generic_op_reshape_consumer_fusion
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.generic
+
+// -----
+// Negative case: the folded dims (j, k, l) have sizes ?, ?, 5, so the inner folded dim k is dynamic.
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
+                                           %arg1 : tensor<?x?x?x5xf32>) ->
+                                           tensor<?x?xf32>
+{
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?x?x5xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x?x5xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// Expected: the tensor_reshape remains (presumably because the stride of j would depend on the dynamic size of k).
+// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
+//       CHECK: linalg.tensor_reshape
+
+// -----
+// indexed_generic variant of the expanding producer reshape case; the payload adds the first loop index (%arg2).
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
+  -> tensor<?x?x4x?xi32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
+  %1 = linalg.indexed_generic {
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
+    ins(%0 : tensor<?x?x4x?xi32>) {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):       // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = addi %arg6, %2 : i32
+    linalg.yield %3 : i32
+  } -> tensor<?x?x4x?xi32>
+  return %1 : tensor<?x?x4x?xi32>
+}
+// Expected: no tensor_reshape before or after the indexed_generic, whose input map is (d0, d1 * 4 + d2, d3).
+// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.indexed_generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape
+
+// -----
+// indexed_generic variant of the collapsing consumer reshape case (?x?x4x5 -> ?x?).
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
+  -> tensor<?x?xi32> {
+  %0 = linalg.indexed_generic {
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
+    ins(%arg0 : tensor<?x?x4x5xi32>) {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):       // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = addi %arg6, %2 : i32
+    linalg.yield %3 : i32
+  } -> tensor<?x?x4x5xi32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x4x5xi32> into tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+// Expected: no tensor_reshape remains and the result map becomes (d0, d1 * 20 + d2 * 5 + d3).
+// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.indexed_generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape
+
+// -----
+// Expanding producer reshape 3x35 -> 3x5x7 read by the generic through the permutation (d0, d2, d1).
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Captures are $-prefixed (global) so the uses after the label bind here, not to $MAP0/$MAP1 from an earlier case.
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+  } -> tensor<3x7x5xf32>
+  return %1 : tensor<3x7x5xf32>
+}
+// Expected: the reshape is folded and the 3x35 source is read through (d0, d1 + d2 * 7).
+// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape
+
+// -----
+// Expanding producer reshape 3x35 -> 3x5x7 read by the generic through the permutation (d1, d2, d0).
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// NOTE(review): over the 5x7x3 loops, (d1, d2, d0) spans 7x3x5 rather than the 3x5x7 input; confirm the intended map.
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+  } -> tensor<5x7x3xf32>
+  return %1 : tensor<5x7x3xf32>
+}
+// Expected: the reshape is folded; (d2, d0 * 7 + d1) is the linearization of (d2, d0, d1), not of (d1, d2, d0).
+// CHECK-LABEL: func @generic_op_120_permultation_reshape_producer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape
+
+// -----
+// Expanding producer reshape 3x35 -> 3x5x7 read by the generic through the permutation (d1, d0, d2).
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Captures are $-prefixed (global) so the uses after the label bind here, not to $MAP0/$MAP1 from an earlier case.
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+  } -> tensor<5x3x7xf32>
+  return %1 : tensor<5x3x7xf32>
+}
+// Expected: the reshape is folded and the 3x35 source is read through (d1, d0 * 7 + d2).
+// CHECK-LABEL: func @generic_op_102_permultation_reshape_producer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape
+
+// -----
+// Collapsing consumer reshape 5x3x7 -> 5x21 applied to a generic whose result map is the permutation (d1, d0, d2).
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// Captures are $-prefixed (global) so the uses after the label bind here, not to $MAP0/$MAP1 from an earlier case.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
+func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
+  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+  } -> tensor<5x3x7xf32>
+  %1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32>
+  return %1 : tensor<5x21xf32>
+}
+// Expected: the reshape is folded and the 5x21 result is written through (d1, d0 * 7 + d2).
+// CHECK-LABEL: func @generic_op_102_permultation_reshape_consumer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape


        


More information about the Mlir-commits mailing list