[Mlir-commits] [mlir] cf194da - [mlir][linalg] Remove IndexedGenericOp support from FusionOnTensors...

Tobias Gysi llvmlistbot at llvm.org
Thu May 13 07:58:09 PDT 2021


Author: Tobias Gysi
Date: 2021-05-13T14:57:16Z
New Revision: cf194da1bbf79d392688dba0c74875829e9873f2

URL: https://github.com/llvm/llvm-project/commit/cf194da1bbf79d392688dba0c74875829e9873f2
DIFF: https://github.com/llvm/llvm-project/commit/cf194da1bbf79d392688dba0c74875829e9873f2.diff

LOG: [mlir][linalg] Remove IndexedGenericOp support from FusionOnTensors...

after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612).

Differential Revision: https://reviews.llvm.org/D102163

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir
    mlir/test/Dialect/Linalg/reshape_fusion.mlir
    mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 93c99038e322a..501c34f5c46b0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -57,17 +57,16 @@ void populateFoldReshapeOpsByExpansionPatterns(
     ControlElementwiseOpsFusionFn controlFoldingReshapes = skipUnitDimReshape);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation.
+/// producer (consumer) generic operation by linearizing the indexing map used
+/// to access the source (target) of the reshape operation in the generic
+/// operation.
 void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation. The patterns are applied only when
-/// the tensor reshape involved is collapsing (introducing) unit-extent
-/// dimensions.
+/// producer (consumer) generic operation by linearizing the indexing map used
+/// to access the source (target) of the reshape operation in the generic
+/// operation. The patterns are applied only when the tensor reshape involved is
+/// collapsing (introducing) unit-extent dimensions.
 void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns);
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 6a0c597b7c3cd..4ecc72ca3d094 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -26,8 +26,8 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-/// Implementation of fusion of generic ops and indexed_generic ops.
-static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
+/// Conditions for elementwise fusion of generic operations.
+static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                      unsigned consumerIdx) {
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
@@ -95,57 +95,20 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
 static void
-generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
-                                 LinalgOp producer, LinalgOp consumer,
+generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
+                                 GenericOp producer, GenericOp consumer,
                                  AffineMap consumerToProducerLoopsMap,
                                  unsigned consumerIdx, unsigned nloops) {
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
   Block *fusedBlock = new Block();
-  fusedOp->getRegion(0).push_back(fusedBlock);
+  fusedOp.region().push_back(fusedBlock);
   BlockAndValueMapping mapper;
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointToStart(fusedBlock);
 
-  // The block arguments are
-  // [index_0, index_1, ... ,
-  //   consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
-  //   producer_operand_0, ... , producer_operand_(n-1)],
-  //   consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
-  // , where n is the number of producer's operand and m is the number
-  // consumer's operand.
-  // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
-  // generic op. In this case, there are no indices in block arguments.
-  unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
-                                    ? producer.getNumLoops()
-                                    : 0;
-  unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
-                                    ? consumer.getNumLoops()
-                                    : 0;
-  unsigned numFusedOpIndices =
-      (isa<IndexedGenericOp>(producer.getOperation()) ||
-       isa<IndexedGenericOp>(consumer.getOperation()))
-          ? std::max(producer.getNumLoops(), consumer.getNumLoops())
-          : 0;
-
-  // 0. Firstly, add all the indices to the block arguments.
-  for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
-    fusedBlock->addArgument(rewriter.getIndexType());
-  // 1. Map consumer indices to fusedBlock indices 1-1.
-  mapper.map(consumerBlock.getArguments().take_front(numConsumerIndices),
-             fusedBlock->getArguments().take_front(numConsumerIndices));
-  // 2a. Embed producer indices into fusedBlock index space 1-1.
-  for (auto it :
-       llvm::zip(producerBlock.getArguments().take_front(numProducerIndices),
-                 fusedBlock->getArguments().take_front(numProducerIndices))) {
-    auto newIndex = rewriter.create<mlir::AffineApplyOp>(
-        producer.getLoc(),
-        consumerToProducerLoopsMap.getSubMap(std::get<0>(it).getArgNumber()),
-        fusedBlock->getArguments().take_front(numFusedOpIndices));
-    mapper.map(std::get<0>(it), newIndex);
-  }
-  // 2b. Add an index operation for every fused loop dimension and use the
+  // 2. Add an index operation for every fused loop dimension and use the
   // `consumerToProducerLoopsMap` to map the producer indices.
   if (producer.hasIndexSemantics()) {
     // Add an index operation for every fused loop dimension.
@@ -169,34 +132,30 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
   assert(consumerIdx < consumer.getNumInputs() &&
          "expected producer of input operand");
   // 3. Consumer input operands up to consumerIdx (exclusive).
-  for (BlockArgument bbArg : consumerBlock.getArguments()
-                                 .drop_front(numConsumerIndices)
-                                 .take_front(consumerIdx)) // input assumption.
+  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
+           consumerIdx)) // input assumption.
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
 
   // Replacing consumerIdx requires getting the cloned, yielded, value from
   // the (cloned) producer block. This happens in step 9.
 
   // 4. Splice in producer's input operands.
-  for (BlockArgument bbArg : producerBlock.getArguments()
-                                 .drop_front(numProducerIndices)
-                                 .take_front(producer.getNumInputs()))
+  for (BlockArgument bbArg :
+       producerBlock.getArguments().take_front(producer.getNumInputs()))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
 
   // 4.b. Producer output operand/map that is fused needs to be mapped to the
   // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
   assert(producer->getNumResults() == 1 && "expected single result producer");
   if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
-    BlockArgument bbArg =
-        producerBlock.getArguments()
-            .drop_front(numConsumerIndices + producer.getNumInputs())
-            // TODO: bbArg index of
-            .front();
+    BlockArgument bbArg = producerBlock.getArguments()
+                              .drop_front(producer.getNumInputs())
+                              // TODO: bbArg index of
+                              .front();
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
   }
   // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
   for (BlockArgument bbArg : consumerBlock.getArguments()
-                                 .drop_front(numConsumerIndices)
                                  .take_front(consumer.getNumInputs())
                                  .drop_front(consumerIdx + 1))
     mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
@@ -232,23 +191,21 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
       assert(!producer->isAncestor(replacement.getDefiningOp()) &&
              "yielded value must have been mapped");
   }
-  mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
-             replacement);
+  mapper.map(consumerBlock.getArgument(consumerIdx), replacement);
   // 10. Clone operations from the consumer to the fused op.
   for (auto &op : consumerBlock.getOperations())
     rewriter.clone(op, mapper);
 
   // Sanity checks.
-  assert(fusedBlock->getNumArguments() ==
-             fusedOp->getNumOperands() + numFusedOpIndices &&
-         "Ill-formed LinalgOp region");
+  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+         "Ill-formed GenericOp region");
 }
 
-static Optional<SmallVector<Value, 1>>
-fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
+static Optional<SmallVector<Value>>
+fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
                        const ControlElementwiseOpsFusionFn &controlFn,
                        PatternRewriter &rewriter) {
-  LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
+  auto consumer = cast<GenericOp>(consumerOpOperand.getOwner());
   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
   if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) ||
       !controlFn(producer->getResult(0), consumerOpOperand))
@@ -311,27 +268,14 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
   assert(producer->getNumResults() == 1 && "expected single result producer");
 
   // Generate the fused op.
-  Operation *fusedOp;
-  if (isa<GenericOp>(producer.getOperation()) &&
-      isa<GenericOp>(consumer.getOperation())) {
-    fusedOp = rewriter.create<GenericOp>(
-        consumer.getLoc(), consumer->getResultTypes(),
-        /*inputs=*/fusedOperands,
-        // TODO: handle outputs.
-        consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-        consumer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr);
-  } else {
-    fusedOp = rewriter.create<IndexedGenericOp>(
-        consumer.getLoc(), consumer->getResultTypes(),
-        /*inputs=*/fusedOperands,
-        // TODO: handle outputs.
-        consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-        consumer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr);
-  }
+  auto fusedOp = rewriter.create<GenericOp>(
+      consumer.getLoc(), consumer->getResultTypes(),
+      /*inputs=*/fusedOperands,
+      // TODO: handle outputs.
+      consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+      consumer.iterator_types(),
+      /*doc=*/nullptr,
+      /*library_call=*/nullptr);
 
   // Construct an AffineMap from consumer loops to producer loops.
   // consumer loop -> tensor index
@@ -348,7 +292,7 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
   generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
                                    consumerToProducerLoopsMap, consumerIdx,
                                    consumer.getNumLoops());
-  return SmallVector<Value, 1>(fusedOp->getResults());
+  return SmallVector<Value>(fusedOp->getResults());
 }
 
 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
@@ -373,7 +317,7 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                         ArrayRef<int64_t> sourceShape,
                                         ArrayRef<AffineMap> reassociationMaps) {
-  SmallVector<AffineExpr, 4> resultExprs;
+  SmallVector<AffineExpr> resultExprs;
   resultExprs.reserve(reassociationMaps.size());
   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
   MLIRContext *context = sourceMap.getContext();
@@ -386,8 +330,8 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
     assert(!collapsedDims.empty());
     unsigned startDim =
         collapsedDims.front().cast<AffineDimExpr>().getPosition();
-    SmallVector<int64_t, 4> sizes;
-    SmallVector<AffineExpr, 4> dimExprs;
+    SmallVector<int64_t> sizes;
+    SmallVector<AffineExpr> dimExprs;
     for (auto en :
          llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
                    sourceExprs.slice(startDim, collapsedDims.size()))) {
@@ -426,22 +370,6 @@ static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp,
   return useIndexMap.isPermutation();
 }
 
-/// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
-/// is a linalg.generic operation, the create a `linalg.generic` operation with
-/// the given `args`. Expects `op` to be `linalg.generic` or
-/// `linalg.indexed_generic`.
-template <typename... Args>
-static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
-                                         Args... args) {
-  if (isa<GenericOp>(op.getOperation()))
-    return rewriter.create<GenericOp>(args...);
-  if (isa<IndexedGenericOp>(op.getOperation()))
-    return rewriter.create<IndexedGenericOp>(args...);
-  llvm_unreachable(
-      "expected only linalg.generic or linalg.indexed_generic ops");
-  return nullptr;
-}
-
 /// Check if the reshape operation is only expansion into/collapsing of
 /// unit-dimension.
 static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
@@ -459,10 +387,10 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
   return true;
 }
 
-/// Conditions for folding a generic/indexed-generic operation with a reshape op
-/// by expanding the iteration space dimensionality for tensor operations. These
-/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
-/// the following fusion pattern.
+/// Conditions for folding a generic operation with a reshape op by expanding
+/// the iteration space dimensionality for tensor operations. These are
+/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
+/// following fusion pattern.
 ///
 ///  Consider
 ///
@@ -476,12 +404,12 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
 ///          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 ///
-///  The reshape can be folded into the `linalgOp` if the
-///  generic/indexed-generic op loop dimensionality is increased to match the
-///  result (operand) of the tensor_reshape when the reshape is expanding
-///  (folding). The indexing_map of the fused tensor in the `linalgOp` and the
-///  reassociation map helps compute the indexing maps of the modified op. For
-///  the above example, based on the reassociation map it can be concluded that
+///  The reshape can be folded into the `genericOp` if its loop dimensionality
+///  is increased to match the result (operand) of the tensor_reshape when the
+///  reshape is expanding (folding). The indexing_map of the fused tensor in the
+///  `genericOp` and the reassociation map helps compute the indexing maps of
+///  the modified op. For the above example, based on the reassociation map it
+///  can be concluded that
 ///
 ///  - The loop used to access the first dimension of the fused tensor is split
 ///    into two.
@@ -520,41 +448,40 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
 ///
 ///  The added reshapes are again expanding patterns, so they will get fused
 ///  with its producers if possible.
-static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
+static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                                unsigned fusedTensorIndex) {
   // Is fusable only if:
-  // - The linalgOp is a generic op, or an indexed_generic.
-  // - All the indexing maps for operands and results in linalgOp are projected
+  // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops in linalgOp are parallel loops.
-  return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
-         linalgOp.hasTensorSemantics() &&
-         llvm::all_of(linalgOp.indexing_maps().getValue(),
+  // - All the loops are parallel loops.
+  return genericOp.hasTensorSemantics() &&
+         llvm::all_of(genericOp.indexing_maps().getValue(),
                       [](Attribute attr) {
                         return attr.cast<AffineMapAttr>()
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
-         llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
+         genericOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
+         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
            return attr.cast<StringAttr>().getValue() ==
                   getParallelIteratorTypeName();
          });
 }
 
 namespace {
-/// Information needed to expand a generic/indexed_generic operation to fold the
-/// reshape with it.
+/// Information needed to expand a generic operation to fold the reshape with
+/// it.
 class ExpansionInfo {
 public:
   // Computes the mapping from original dimensions of the op to the dimensions
   // of the expanded op given the `indexingMap` of the fused operand/result of
-  // the generic/indexed_generic op, the `reassocationMaps` of the reshape op
-  // and the shape of the expanded op.
+  // the generic op, the `reassociationMaps` of the reshape op and the shape of
+  // the expanded op.
   LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
                         ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape);
+                        ArrayRef<int64_t> expandedShape,
+                        PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedDims(unsigned i) const {
@@ -567,10 +494,10 @@ class ExpansionInfo {
 private:
   /// Reassociation from the dimensions in the original operation to the
   /// dimension of the expanded operation.
-  SmallVector<ReassociationIndices, 4> reassociation;
+  SmallVector<ReassociationIndices> reassociation;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t, 4>, 4> expandedShapeMap;
+  SmallVector<SmallVector<int64_t>> expandedShapeMap;
   unsigned expandedOpNumDims;
 };
 } // namespace
@@ -578,7 +505,8 @@ class ExpansionInfo {
 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      unsigned fusedTensorIndex,
                                      ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape) {
+                                     ArrayRef<int64_t> expandedShape,
+                                     PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
@@ -586,13 +514,13 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   Optional<SmallVector<int64_t, 4>> originalLoopRange =
       linalgOp.getStaticLoopRanges();
   if (!originalLoopRange)
-    return linalgOp.emitError("unable to find loop range for operation");
+    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
 
   reassociation.clear();
   expandedShapeMap.clear();
   // Compute the number of dimension in the expanded op that correspond to each
   // dimension of the original op.
-  SmallVector<unsigned, 4> numExpandedDims(fusedIndexMap.getNumDims(), 1);
+  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
   expandedShapeMap.resize(fusedIndexMap.getNumDims());
   for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
@@ -627,17 +555,19 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
 /// Note that this could be extended to handle dynamic case, but the
 /// implementation below uses `affine.apply` which seems to have issues when the
 /// shapes are not static.
-LogicalResult isIndexedOpExpandable(LinalgOp linalgOp,
-                                    const ExpansionInfo &expansionInfo) {
+LogicalResult isGenericOpExpandable(GenericOp genericOp,
+                                    const ExpansionInfo &expansionInfo,
+                                    PatternRewriter &rewriter) {
+  if (!genericOp.hasIndexSemantics())
+    return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
     if (expandedShape.size() == 1)
       continue;
     for (int64_t shape : expandedShape.drop_front()) {
       if (ShapedType::isDynamic(shape)) {
-        return linalgOp.emitError(
-            "unable to fuse indexed generic op where the expanded dim is "
-            "dynamic");
+        return rewriter.notifyMatchFailure(
+            genericOp, "cannot expand due to index semantics and dynamic dims");
       }
     }
   }
@@ -649,7 +579,7 @@ LogicalResult isIndexedOpExpandable(LinalgOp linalgOp,
 static AffineMap
 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                            const ExpansionInfo &expansionInfo) {
-  SmallVector<AffineExpr, 4> newExprs;
+  SmallVector<AffineExpr> newExprs;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
     SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
@@ -668,7 +598,7 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
 static RankedTensorType getExpandedType(RankedTensorType originalType,
                                         AffineMap indexingMap,
                                         const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t, 4> expandedShape;
+  SmallVector<int64_t> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
     auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
@@ -682,15 +612,15 @@ static RankedTensorType getExpandedType(RankedTensorType originalType,
 /// the expanded operation. The same method is used to compute the
 /// `linalg.tensor_reshape` used to collapse the result of the expanded op to
 /// get the value that can replace all uses of the results of the original op.
-static SmallVector<ReassociationIndices, 4>
+static SmallVector<ReassociationIndices>
 getReassociationForExpansion(AffineMap indexingMap,
                              const ExpansionInfo &expansionInfo) {
-  SmallVector<ReassociationIndices, 4> reassociation;
+  SmallVector<ReassociationIndices> reassociation;
   unsigned numReshapeDims = 0;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
-    auto indices = llvm::to_vector<2>(
+    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
     reassociation.emplace_back(std::move(indices));
     numReshapeDims += numExpandedDims;
@@ -698,66 +628,14 @@ getReassociationForExpansion(AffineMap indexingMap,
   return reassociation;
 }
 
-/// Build the body of the expanded IndexedGenericOp. The arguments for the
-/// induction variables of the original operation need to be recovered by
-/// linearizing the arguments of the corresponding dimensions of the expanded
-/// op. For now it is assumed that the shapes of the expanded op needed for
-/// linearization are static.
-static void buildExpandedIndexedGenericOpRegion(
-    PatternRewriter &rewriter, Location loc, Region &originalOpRegion,
-    Region &fusedOpRegion, const ExpansionInfo &expansionInfo) {
-  assert(fusedOpRegion.empty() && "expected fused op to have empty region");
-  // Create an entry block in the fused region with same number of arguments
-  // as the fused op
-  Block *fusedEntryBlock = new Block;
-  fusedOpRegion.push_back(fusedEntryBlock);
-  rewriter.cloneRegionBefore(originalOpRegion, fusedOpRegion,
-                             fusedOpRegion.end());
-
-  // Merge the entry block of the fused op with the cloned blocks. For this
-  // compute the value for arguments of the region in the original operation
-  // in terms of the arguments of the fused op. Since the original operation
-  // is expanded, the expanded dimensions need to be folded back to get the
-  // replacement value for the arguments corresponding to interation index.
-  // For now this expects that all the loop ranges are constants, which is
-  // true if the shapes are all static. This has already been checked in the
-  // precondition.
-  using namespace edsc::op;
-  using namespace edsc::intrinsics;
-  OpBuilder::InsertionGuard guard(rewriter);
-  SmallVector<Value, 4> argReplacements(originalOpRegion.getNumArguments());
-  rewriter.setInsertionPointToStart(fusedEntryBlock);
-  edsc::ScopedContext scopedContext(rewriter, loc);
-  IndexType indexType = rewriter.getIndexType();
-  for (auto i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
-    ArrayRef<int64_t> expandedDimsShape =
-        expansionInfo.getExpandedShapeOfDim(i).drop_front();
-    for (unsigned shape : expandedDimsShape) {
-      assert(!ShapedType::isDynamic(shape));
-      linearizedIndex = linearizedIndex * std_constant_index(shape);
-      linearizedIndex =
-          linearizedIndex + fusedEntryBlock->addArgument(indexType);
-    }
-    argReplacements[i] = linearizedIndex;
-  }
-  for (auto i : llvm::seq<unsigned>(expansionInfo.getOrigOpNumDims(),
-                                    argReplacements.size())) {
-    argReplacements[i] =
-        fusedEntryBlock->addArgument(originalOpRegion.getArgument(i).getType());
-  }
-  rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
-                       argReplacements);
-}
-
 /// Update the body of an expanded linalg operation having index semantics. The
 /// indices of the original operation need to be recovered by linearizing the
 /// indices of the correspoding dimensions of the expanded operation. For now it
 /// is assumed that the shapes of the expanded operation needed for
 /// linearization are static.
-static void updateExpandedIndexOpRegion(PatternRewriter &rewriter, Location loc,
-                                        Region &fusedRegion,
-                                        const ExpansionInfo &expansionInfo) {
+static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
+                                          Location loc, Region &fusedRegion,
+                                          const ExpansionInfo &expansionInfo) {
   // Replace the original indices by the linearization of the expanded indices.
   for (IndexOp indexOp :
        llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
@@ -793,112 +671,100 @@ static void updateExpandedIndexOpRegion(PatternRewriter &rewriter, Location loc,
   }
 }
 
-/// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
-/// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
-/// conditions have been satisfied.
-static Optional<SmallVector<Value, 1>>
-fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
+/// Implements the fusion of a tensor_reshape op and a generic op as explained
+/// in `isFusableWithReshapeByExpansion`. Assumes that those conditions have
+/// been satisfied.
+static Optional<SmallVector<Value>>
+fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
                            unsigned fusedTensorIndex,
                            PatternRewriter &rewriter) {
-  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
+  assert(isFusableWithReshapeByDimExpansion(genericOp, fusedTensorIndex) &&
          "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
   bool isExpanding =
       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
   RankedTensorType expandedType =
       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
-  bool hasIndexSemantics = linalgOp.hasIndexSemantics() ||
-                           isa<IndexedGenericOp>(linalgOp.getOperation());
 
   ExpansionInfo expansionInfo;
-  if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
+  if (failed(expansionInfo.compute(genericOp, fusedTensorIndex,
                                    reshapeOp.getReassociationMaps(),
-                                   expandedType.getShape())))
+                                   expandedType.getShape(), rewriter)))
     return llvm::None;
 
-  if (hasIndexSemantics &&
-      failed(isIndexedOpExpandable(linalgOp, expansionInfo)))
+  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
     return llvm::None;
 
   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
-      llvm::map_range(linalgOp.getIndexingMaps(), [&](AffineMap m) {
+      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
       }));
 
-  SmallVector<Value, 4> expandedOpOperands;
-  for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+  SmallVector<Value> expandedOpOperands;
+  for (auto operand : llvm::enumerate(genericOp.getInputs())) {
     if (operand.index() == fusedTensorIndex) {
       expandedOpOperands.push_back(reshapeOp.src());
       continue;
     }
-    AffineMap indexingMap = linalgOp.getInputIndexingMap(operand.index());
+    AffineMap indexingMap = genericOp.getInputIndexingMap(operand.index());
     RankedTensorType expandedOperandType =
         getExpandedType(operand.value().getType().cast<RankedTensorType>(),
                         indexingMap, expansionInfo);
     if (expandedOperandType != operand.value().getType()) {
       // Reshape the operand to get the right type.
-      SmallVector<ReassociationIndices, 4> reassociation =
+      SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
       expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
-          linalgOp.getLoc(), expandedOperandType, operand.value(),
+          genericOp.getLoc(), expandedOperandType, operand.value(),
           reassociation));
       continue;
     }
     expandedOpOperands.push_back(operand.value());
   }
 
-  Location loc = linalgOp.getLoc();
-  SmallVector<Value, 1> outputs;
-  for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
-    AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
+  Location loc = genericOp.getLoc();
+  SmallVector<Value> outputs;
+  for (auto result : llvm::enumerate(genericOp.getOutputs())) {
+    AffineMap indexingMap = genericOp.getOutputIndexingMap(result.index());
     RankedTensorType expandedOutputType =
         getExpandedType(result.value().getType().cast<RankedTensorType>(),
                         indexingMap, expansionInfo);
     if (expandedOutputType != result.value().getType()) {
-      SmallVector<ReassociationIndices, 4> reassociation =
+      SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
       outputs.push_back(rewriter.create<TensorReshapeOp>(
-          linalgOp.getLoc(), expandedOutputType, result.value(),
+          genericOp.getLoc(), expandedOutputType, result.value(),
           reassociation));
     }
   }
 
   // The iterator types of the expanded op are all parallel.
-  SmallVector<StringRef, 4> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
-                                          getParallelIteratorTypeName());
+  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
+                                       getParallelIteratorTypeName());
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
-  LinalgOp fusedOp = createLinalgOpOfSameType(
-      linalgOp, rewriter, linalgOp.getLoc(), resultTypes,
-      /*inputs=*/expandedOpOperands, outputs, expandedOpIndexingMaps,
-      iteratorTypes);
+  auto fusedOp =
+      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
+                                 /*inputs=*/expandedOpOperands, outputs,
+                                 expandedOpIndexingMaps, iteratorTypes);
   Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = linalgOp->getRegion(0);
-
-  if (isa<GenericOp>(linalgOp.getOperation())) {
-    rewriter.cloneRegionBefore(originalRegion, fusedRegion,
-                               fusedRegion.begin());
-  } else {
-    assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
-    buildExpandedIndexedGenericOpRegion(rewriter, loc, originalRegion,
-                                        fusedRegion, expansionInfo);
-  }
+  Region &originalRegion = genericOp->getRegion(0);
+  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
 
   // Update the index accesses after the expansion.
-  if (linalgOp.hasIndexSemantics())
-    updateExpandedIndexOpRegion(rewriter, loc, fusedRegion, expansionInfo);
+  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
 
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
-  SmallVector<Value, 1> resultVals;
-  for (auto result : llvm::enumerate(linalgOp->getResults())) {
+  SmallVector<Value> resultVals;
+  for (auto result : llvm::enumerate(genericOp->getResults())) {
     if (!isExpanding &&
         resultTypes[result.index()] != result.value().getType()) {
-      SmallVector<ReassociationIndices, 4> reassociation =
+      SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(
-              linalgOp.getOutputIndexingMap(result.index()), expansionInfo);
+              genericOp.getOutputIndexingMap(result.index()), expansionInfo);
       resultVals.push_back(rewriter.create<TensorReshapeOp>(
-          linalgOp.getLoc(), result.value().getType(),
+          genericOp.getLoc(), result.value().getType(),
           fusedOp->getResult(result.index()), reassociation));
     } else {
       resultVals.push_back(fusedOp->getResult(result.index()));
@@ -934,22 +800,21 @@ namespace {
 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
 ///        -> tensor<?x?x4x?xf32>
-template <typename LinalgOpTy, bool foldUnitDimReshapesOnly>
+template <bool foldUnitDimReshapesOnly>
 struct FoldProducerReshapeOpByLinearization
-    : public OpRewritePattern<LinalgOpTy> {
-  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
+    : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(LinalgOpTy op,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!op.hasTensorSemantics())
+    if (!genericOp.hasTensorSemantics())
       return failure();
-    LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
-    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
       TensorReshapeOp reshapeOp =
           operand.value().getDefiningOp<TensorReshapeOp>();
       if (!reshapeOp ||
           !isTensorReshapeOpFoldableByLinearization(
-              reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
+              reshapeOp, genericOp.getInputIndexingMap(operand.index()),
               /*asProducer =*/true) ||
           (foldUnitDimReshapesOnly &&
            !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
@@ -957,15 +822,15 @@ struct FoldProducerReshapeOpByLinearization
         continue;
 
       // Compute the fused operands list,
-      SmallVector<Value, 2> fusedOperands(linalgOp.getInputs());
+      SmallVector<Value> fusedOperands(genericOp.getInputs());
       fusedOperands[operand.index()] = reshapeOp.src();
-      fusedOperands.append(linalgOp.getOutputs().begin(),
-                           linalgOp.getOutputs().end());
+      fusedOperands.append(genericOp.getOutputs().begin(),
+                           genericOp.getOutputs().end());
 
       // Compute indexing_maps for the fused operation. The indexing_maps for
       // the operands of the consumers that arent fused are the same.
       SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
-          op.indexing_maps().template getAsValueRange<AffineMapAttr>());
+          genericOp.indexing_maps().template getAsValueRange<AffineMapAttr>());
 
       // Accepted consumer maps are either identity or permutation.
       auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
@@ -984,13 +849,14 @@ struct FoldProducerReshapeOpByLinearization
       // inverted. Without this the resultant op is not legal.
       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
         return rewriter.notifyMatchFailure(
-            op, "fused op loop bound computation failed");
+            genericOp, "fused op loop bound computation failed");
       }
 
-      rewriter.startRootUpdate(op);
-      op->setOperands(fusedOperands);
-      op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps));
-      rewriter.finalizeRootUpdate(op);
+      rewriter.startRootUpdate(genericOp);
+      genericOp->setOperands(fusedOperands);
+      genericOp.indexing_mapsAttr(
+          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
+      rewriter.finalizeRootUpdate(genericOp);
       return success();
     }
     return failure();
@@ -1013,7 +879,7 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
 
 /// Pattern to move rank reducing reshape after an elementwise linalg generic
 /// op. This is useful to expose more fusion opportunities between named ops and
-/// generic op. This can only be done if there is no broadcast or permuation
+/// generic ops. This can only be done if there is no broadcast or permutation
 /// within the dimensions we need to merge.
 ///
 /// For example,
@@ -1040,27 +906,27 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
 ///  %3 = linalg.tensor_reshape %2 [
 ///    #affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
 ///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
-template <typename GenericOpTy>
-struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
-  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(GenericOpTy op,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Only apply to elementwise linalg on tensor.
-    if (!op.hasTensorSemantics() ||
-        op.getNumParallelLoops() != op.getNumLoops())
+    if (!genericOp.hasTensorSemantics() ||
+        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
       return failure();
     // Only support identity output maps. It could be extended to permuations if
     // needed.
-    if (llvm::any_of(op.getOutputIndexingMaps(),
+    if (llvm::any_of(genericOp.getOutputIndexingMaps(),
                      [](AffineMap map) { return !map.isIdentity(); }))
       return failure();
-    int64_t destRank = op.getNumParallelLoops();
-    SmallVector<Value, 4> newOperands = llvm::to_vector<4>(op.getInputs());
+    int64_t destRank = genericOp.getNumParallelLoops();
+    SmallVector<Value, 4> newOperands =
+        llvm::to_vector<4>(genericOp.getInputs());
     TensorReshapeOp reshapeFound;
     // 1. Look for tensor_reshape operands and figure out save the dimensions
     // merged.
-    for (auto operand : llvm::enumerate(op.getInputs())) {
+    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
       TensorReshapeOp reshapeOp =
           operand.value().template getDefiningOp<TensorReshapeOp>();
       if (!reshapeOp || reshapeOp.getSrcType().getRank() >
@@ -1069,7 +935,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
       }
       // TODO: We could support non-identity map as long as the merged
       // dimensions are still contiguous.
-      if (!op.getIndexingMaps()[operand.index()].isIdentity())
+      if (!genericOp.getIndexingMaps()[operand.index()].isIdentity())
         continue;
       if (reshapeFound) {
         // Only support a second reshape op if it has the same reassociate maps.
@@ -1087,7 +953,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
     // Calculate the reassociation indices and rassociated reverse map.
     SmallVector<ReassociationIndices> reassociation =
         getReassociationIndices(reshapeFound.getReassociationMaps());
-    SmallVector<unsigned, 4> remap(destRank);
+    SmallVector<unsigned> remap(destRank);
     for (auto &indices : llvm::enumerate(reassociation)) {
       for (int64_t index : indices.value()) {
         remap[index] = indices.index();
@@ -1096,9 +962,9 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
     // 2. Verify that we can merge the dimensions in the linalg and that we
     // don't need to create new reshapes operands. Inserting new reshape
     // operands would defeat the purpose of the transformation.
-    for (auto operand : llvm::enumerate(op.getInputs())) {
+    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
       if (operand.value() == newOperands[operand.index()]) {
-        AffineMap map = op.getIndexingMaps()[operand.index()];
+        AffineMap map = genericOp.getIndexingMaps()[operand.index()];
         for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
           if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
             return failure();
@@ -1108,70 +974,69 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
 
     // 3. Calculate the affine map remapping and the reassociation to apply to
     // output tensors.
-    SmallVector<AffineMap, 4> newMaps;
+    SmallVector<AffineMap> newMaps;
     unsigned newRank = reassociation.size();
-    for (auto map : op.getIndexingMaps()) {
+    for (auto map : genericOp.getIndexingMaps()) {
       SmallVector<AffineExpr> newExprs;
       for (auto expr : map.getResults()) {
         unsigned position = expr.template cast<AffineDimExpr>().getPosition();
         // Skip dimension merged except for the last of the group.
         if (reassociation[remap[position]].back() == position) {
           newExprs.push_back(
-              getAffineDimExpr(remap[position], op.getContext()));
+              getAffineDimExpr(remap[position], genericOp.getContext()));
         }
       }
-      newMaps.push_back(AffineMap::get(newRank, 0, newExprs, op.getContext()));
+      newMaps.push_back(
+          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
     }
 
     // 4. Reshape the output tensors.
     SmallVector<Value> newOutputs;
     SmallVector<Type> newOutputTypes;
-    for (auto output : op.outputs()) {
+    for (auto output : genericOp.outputs()) {
       auto newOutputType = RankedTensorType::get(
           reshapeFound.getSrcType().getShape(),
           output.getType().template cast<RankedTensorType>().getElementType());
       Value newOutput = rewriter.create<TensorReshapeOp>(
-          op->getLoc(), newOutputType, output, reassociation);
+          genericOp->getLoc(), newOutputType, output, reassociation);
       newOutputTypes.push_back(newOutputType);
       newOutputs.push_back(newOutput);
     }
     // 5. Create a new generic op with lowerer rank.
-    SmallVector<StringRef, 4> iteratorTypes(newRank,
-                                            getParallelIteratorTypeName());
-    auto newOp =
-        rewriter.create<GenericOpTy>(op->getLoc(), newOutputTypes, newOperands,
-                                     newOutputs, newMaps, iteratorTypes);
-    rewriter.inlineRegionBefore(op.region(), newOp.region(),
+    SmallVector<StringRef> iteratorTypes(newRank,
+                                         getParallelIteratorTypeName());
+    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
+                                            newOperands, newOutputs, newMaps,
+                                            iteratorTypes);
+    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                 newOp.region().begin());
     // 6. Reshape the so that the type matches the uses.
     SmallVector<Value> newResults;
     for (auto result : llvm::enumerate(newOp->getResults())) {
       newResults.push_back(rewriter.create<TensorReshapeOp>(
-          op->getLoc(), op.getOutputTensorTypes()[result.index()],
+          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
           result.value(), reassociation));
     }
-    rewriter.replaceOp(op, newResults);
+    rewriter.replaceOp(genericOp, newResults);
     return success();
   }
 };
 
-/// Pattern to fuse a tensor_reshape op with its consumer
-/// generic/indexed_generic op, when the reshape op is collapsing
-/// dimensions. The dimensionality of the loop in the consumer is expanded.
-template <typename GenericOpTy>
+/// Pattern to fuse a tensor_reshape op with its consumer generic op, when the
+/// reshape op is collapsing dimensions. The dimensionality of the loop in the
+/// consumer is expanded.
 class FoldWithProducerReshapeOpByExpansion
-    : public OpRewritePattern<GenericOpTy> {
+    : public OpRewritePattern<GenericOp> {
 public:
   FoldWithProducerReshapeOpByExpansion(
       MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
       PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOpTy>(context, benefit),
+      : OpRewritePattern<GenericOp>(context, benefit),
         controlFoldingReshapes(foldReshapes) {}
 
-  LogicalResult matchAndRewrite(GenericOpTy genericOp,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
-    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+    for (auto operand : llvm::enumerate(genericOp.getInputs())) {
       TensorReshapeOp reshapeOp =
           operand.value().getDefiningOp<TensorReshapeOp>();
       if (!reshapeOp)
@@ -1181,14 +1046,14 @@ class FoldWithProducerReshapeOpByExpansion
       // - All constraints of fusing with reshape by expansion are met.
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
-          !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
+          !isFusableWithReshapeByDimExpansion(genericOp, operand.index()) ||
           (!controlFoldingReshapes(
               reshapeOp->getResult(0),
-              linalgOp.getInputOpOperands()[operand.index()])))
+              genericOp.getInputOpOperands()[operand.index()])))
         continue;
 
-      Optional<SmallVector<Value, 1>> replacementValues =
-          fuseWithReshapeByExpansion(linalgOp, reshapeOp, operand.index(),
+      Optional<SmallVector<Value>> replacementValues =
+          fuseWithReshapeByExpansion(genericOp, reshapeOp, operand.index(),
                                      rewriter);
       if (!replacementValues)
         return failure();
@@ -1211,10 +1076,9 @@ struct FoldConsumerReshapeOpByLinearization
 
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
-    if (!producer ||
-        !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
-        !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
+    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
+    if (!producer || !producer.hasTensorSemantics() ||
+        producer.getNumOutputs() != 1 ||
         !isTensorReshapeOpFoldableByLinearization(
             reshapeOp, producer.getOutputIndexingMap(0),
             /*asProducer =*/false) ||
@@ -1251,8 +1115,8 @@ struct FoldConsumerReshapeOpByLinearization
     Location loc = producer.getLoc();
     Value output = rewriter.create<TensorReshapeOp>(
         loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
-    LinalgOp fusedOp = createLinalgOpOfSameType(
-        producer, rewriter, loc, reshapeOp.getResultType(),
+    auto fusedOp = rewriter.create<GenericOp>(
+        loc, reshapeOp.getResultType(),
         /*inputs=*/producer.getInputs(),
         // TODO: handle outputs.
         /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
@@ -1280,16 +1144,15 @@ struct FoldReshapeWithGenericOpByExpansion
     // - All constraints of fusing with reshape by expansion are met.
     if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank())
       return failure();
-    LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
+    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
     if (!producer || producer.getNumOutputs() != 1 ||
         !isFusableWithReshapeByDimExpansion(producer,
                                             producer.getNumInputs()) ||
         isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
                                reshapeOp.getReassociationMaps()))
       return failure();
-    Optional<SmallVector<Value, 1>> replacementValues =
-        fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
-                                   rewriter);
+    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
+        producer, reshapeOp, producer.getNumInputs(), rewriter);
     if (!replacementValues)
       return failure();
     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
@@ -1297,20 +1160,18 @@ struct FoldReshapeWithGenericOpByExpansion
   }
 };
 
-/// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
-template <typename LinalgOpTy>
-class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
+/// Pattern to fold a generic op with a splat constant.
+class FoldSplatConstants : public OpRewritePattern<GenericOp> {
 public:
   FoldSplatConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                      PatternBenefit benefit = 1)
-      : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
+      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
 
-  LogicalResult matchAndRewrite(LinalgOpTy op,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!op.hasTensorSemantics())
+    if (!genericOp.hasTensorSemantics())
       return failure();
-    LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
-    for (auto operand : llvm::enumerate(linalgOp.getInputOpOperands())) {
+    for (auto operand : llvm::enumerate(genericOp.getInputOpOperands())) {
       Operation *def = operand.value().get().getDefiningOp();
       DenseElementsAttr constantAttr;
       if (!def ||
@@ -1320,49 +1181,46 @@ class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
         continue;
 
       // The indexing_maps for the operands of the fused operation are same as
-      // those for the operands of the linalgOp without the indexing map at
+      // those for the operands of the genericOp without the indexing map at
       // operand.index()
       SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
-          linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
+          genericOp.indexing_maps().getAsValueRange<AffineMapAttr>());
       fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
 
       // Check if the operation shapes to loops map is computable.
       if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
         return rewriter.notifyMatchFailure(
-            linalgOp, "fused op loop bound computation failed");
+            genericOp, "fused op loop bound computation failed");
       }
 
-      // The operands list is same as the linalgOp with the argument for
+      // The operands list is the same as the genericOp's with the argument for
       // constant index dropped.
-      SmallVector<Value, 4> fusedOperands(linalgOp.getInputs());
+      SmallVector<Value> fusedOperands(genericOp.getInputs());
       fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
 
       // Create a constant scalar value from the splat constant.
       Value scalarConstant = rewriter.create<ConstantOp>(
           def->getLoc(), constantAttr.getSplatValue());
 
-      LinalgOp fusedOp = createLinalgOpOfSameType(
-          linalgOp, rewriter, rewriter.getUnknownLoc(),
-          linalgOp->getResultTypes(),
+      auto fusedOp = rewriter.create<GenericOp>(
+          rewriter.getUnknownLoc(), genericOp->getResultTypes(),
           /*inputs=*/fusedOperands,
-          /*outputs=*/linalgOp.getOutputs(),
+          /*outputs=*/genericOp.getOutputs(),
           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-          linalgOp.iterator_types(),
+          genericOp.iterator_types(),
           /*doc=*/nullptr,
           /*library_call=*/nullptr);
 
       // Map the block argument corresponding to the replaced argument with the
       // scalar constant.
-      Region &linalgOpRegion = linalgOp->getRegion(0);
-      Block &entryBlock = *linalgOpRegion.begin();
-      unsigned argIndex = entryBlock.getNumArguments() -
-                          linalgOp.getNumShapedOperands() + operand.index();
+      Region &region = genericOp->getRegion(0);
+      Block &entryBlock = *region.begin();
       BlockAndValueMapping mapping;
-      mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
+      mapping.map(entryBlock.getArgument(operand.index()), scalarConstant);
       Region &fusedRegion = fusedOp->getRegion(0);
-      rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion,
-                                 fusedRegion.begin(), mapping);
-      rewriter.replaceOp(linalgOp, fusedOp->getResults());
+      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
+                                 mapping);
+      rewriter.replaceOp(genericOp, fusedOp->getResults());
       return success();
     }
     return failure();
@@ -1373,20 +1231,15 @@ class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
 };
 } // namespace
 
-static Optional<SmallVector<Value, 1>>
+static Optional<SmallVector<Value>>
 fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
+                   GenericOp producer,
                    const ControlElementwiseOpsFusionFn &controlFn) {
-  Operation *producer = consumerOpOperand.get().getDefiningOp();
-  if (!producer || producer->getNumResults() != 1)
-    return llvm::None;
-
-  // Fuse when consumer is GenericOp or IndexedGenericOp.
-  if (!isa<GenericOp, IndexedGenericOp>(consumerOpOperand.getOwner()) ||
-      !isa<GenericOp, IndexedGenericOp>(producer))
+  if (producer->getNumResults() != 1)
     return llvm::None;
 
-  return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
-                                controlFn, rewriter);
+  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
+                                rewriter);
 }
 
 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
@@ -1398,25 +1251,24 @@ bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
 
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
-template <typename LinalgOpTy>
-class FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
+class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 public:
   FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                      PatternBenefit benefit = 1)
-      : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
+      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
 
-  LogicalResult matchAndRewrite(LinalgOpTy op,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
-    for (OpOperand &opOperand : op.getShapedOpOperands()) {
-      LinalgOp producerOp =
-          dyn_cast_or_null<LinalgOp>(opOperand.get().getDefiningOp());
-      if (!producerOp || !producerOp.hasTensorSemantics())
+    for (OpOperand &opOperand : genericOp.getShapedOpOperands()) {
+      auto producer =
+          dyn_cast_or_null<GenericOp>(opOperand.get().getDefiningOp());
+      if (!producer || !producer.hasTensorSemantics())
         continue;
-      Optional<SmallVector<Value, 1>> fusedOpResults =
-          fuseElementwiseOps(rewriter, opOperand, controlFn);
+      Optional<SmallVector<Value>> fusedOpResults =
+          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
       if (fusedOpResults) {
-        rewriter.replaceOp(op, *fusedOpResults);
+        rewriter.replaceOp(genericOp, *fusedOpResults);
         return success();
       }
     }
@@ -1445,8 +1297,7 @@ struct FusionOfTensorOpsPass
   }
 };
 
-/// Pass to test folding of reshape op with generic/indexed_generic ops by
-/// linearization.
+/// Pass to test folding of reshape ops with generic ops by linearization.
 struct FoldReshapeOpsByLinearizationPass
     : public LinalgFoldReshapeOpsByLinearizationBase<
           FoldReshapeOpsByLinearizationPass> {
@@ -1462,16 +1313,14 @@ struct FoldReshapeOpsByLinearizationPass
 
 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
-               FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
+  patterns.add<FoldProducerReshapeOpByLinearization<false>,
                FoldConsumerReshapeOpByLinearization<false>>(
       patterns.getContext());
 }
 
 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
-               FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
+  patterns.add<FoldProducerReshapeOpByLinearization<true>,
                FoldConsumerReshapeOpByLinearization<true>>(
       patterns.getContext());
 }
@@ -1480,18 +1329,15 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     RewritePatternSet &patterns,
     ControlElementwiseOpsFusionFn controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
-  patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
-               FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
-      patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
+                                                     controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
   auto *context = patterns.getContext();
-  patterns
-      .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
-           FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
-          context, options.controlElementwiseOpsFusionFn);
+  patterns.add<FuseElementwiseOps, FoldSplatConstants>(
+      context, options.controlElementwiseOpsFusionFn);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
                                             options.controlFoldingReshapesFn);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
@@ -1502,8 +1348,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
 
 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<PushExpandingReshape<GenericOp>,
-               PushExpandingReshape<IndexedGenericOp>>(context);
+  patterns.add<PushExpandingReshape>(context);
 }
 
 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 7b43c0ffbd5b0..36a1e45839ec5 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -205,39 +205,6 @@ func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
 
 // -----
 
-#map0 = affine_map<(d0, d1, d2) -> (d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>)
-                                         -> tensor<5x?x?xf32>
-{
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c2 = constant 2 : index
-  %cst = constant dense<42.0> : tensor<5xf32>
-  %0 = memref.dim %arg0, %c1 : tensor<5x?x?xf32>
-  %1 = memref.dim %arg0, %c2 : tensor<5x?x?xf32>
-  %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32>
-  %3 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map1, #map1],
-    iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%cst, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>)
-    outs(%2 : tensor<5x?x?xf32>) {
-    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32, %arg5 : f32, %arg6 : f32):
-      %4 = mulf %arg4, %arg5 : f32
-      linalg.yield %4 : f32
-    } -> tensor<5x?x?xf32>
-  return %3 : tensor<5x?x?xf32>
-}
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func @indexed_generic_op_constant_fusion
-//       CHECK:   %[[CST:.*]] = constant {{.*}} : f32
-//       CHECK:   linalg.generic
-//       CHECK:   ^{{[a-zA-Z0-9_]*}}
-//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
-//       CHECK:     mulf %[[CST]], %[[ARG4]]
-
-// -----
-
 #map0 = affine_map<(d0, d1, d2) -> ()>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
@@ -270,89 +237,6 @@ func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
 
 // -----
 
-#map0 = affine_map<(d0, d1, d2) -> ()>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @indexed_generic_op_zero_dim_constant_fusion
-  (%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
-{
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c2 = constant 2 : index
-  %cst = constant dense<42.0> : tensor<f32>
-  %0 = memref.dim %arg0, %c1 : tensor<5x?x?xf32>
-  %1 = memref.dim %arg0, %c2 : tensor<5x?x?xf32>
-  %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32>
-  %3 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map1, #map1],
-    iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%cst, %arg0 : tensor<f32>, tensor<5x?x?xf32>)
-    outs(%2 : tensor<5x?x?xf32>) {
-    ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4: f32, %arg5: f32, %arg6: f32):
-      %4 = mulf %arg4, %arg5 : f32
-      linalg.yield %4 : f32
-    } -> tensor<5x?x?xf32>
-  return %3 : tensor<5x?x?xf32>
-}
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func @indexed_generic_op_zero_dim_constant_fusion
-//       CHECK:   %[[CST:.*]] = constant {{.*}} : f32
-//       CHECK:   linalg.generic
-//       CHECK:   ^{{[a-zA-Z0-9_]*}}
-//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
-//       CHECK:     mulf %[[CST]], %[[ARG4]]
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
-                                           %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
-  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
-  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
-  %3 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel"] }
-    ins(%arg0, %arg1  : tensor<?x?xi32>, tensor<?x?xi32>)
-    outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):       // no predecessors
-      %10 = addi %arg2, %arg3 : i32
-      linalg.yield %10 : i32
-    } -> tensor<?x?xi32>
-  %4 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map0],
-    iterator_types = ["parallel", "parallel"] }
-    ins(%3 : tensor<?x?xi32>)
-    outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32):       // no predecessors
-      %5 = index_cast %arg2 : index to i32
-      %6 = index_cast %arg3 : index to i32
-      %7 = addi %arg4, %5 : i32
-      %8 = subi %7, %6 : i32
-      linalg.yield %8 : i32
-    } -> tensor<?x?xi32>
-  return %4 : tensor<?x?xi32>
-}
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @generic_op_indexed_generic_op_fusion
-//   CHECK-NOT: linalg.indexed_generic
-//       CHECK: linalg.generic
-// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
-//      CHECK: ^{{[a-zA-Z0-9_]*}}
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
-//      CHECK:   %[[ARG0:.+]] = linalg.index 0 : index
-//      CHECK:   %[[ARG1:.+]] = linalg.index 1 : index
-//      CHECK:   %[[VAL1:.+]] = addi %[[ARG2]], %[[ARG3]] : i32
-//      CHECK:   %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
-//      CHECK:   %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
-//      CHECK:   %[[VAL2:.+]] = addi %[[VAL1]], %[[ADD_OPERAND]] : i32
-//      CHECK:   %[[VAL3:.+]] = subi %[[VAL2]], %[[SUB_OPERAND]] : i32
-//      CHECK:   linalg.yield %[[VAL3]] : i32
-
-// -----
-
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>,
                                        %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
@@ -405,56 +289,6 @@ func @producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>,
 
 // -----
 
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
-                                           %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
-  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
-  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
-  %3 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map0],
-    iterator_types = ["parallel", "parallel"] }
-    ins(%arg0 : tensor<?x?xi32>)
-    outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32):       // no predecessors
-      %4 = index_cast %arg2 : index to i32
-      %5 = index_cast %arg3 : index to i32
-      %6 = addi %arg4, %4 : i32
-      %7 = subi %6, %5 : i32
-      linalg.yield %7 : i32
-    } -> tensor<?x?xi32>
-  %4 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel"] }
-    ins(%3, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
-    outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):       // no predecessors
-      %10 = addi %arg2, %arg3 : i32
-      linalg.yield %10 : i32
-    } -> tensor<?x?xi32>
-  return %4 : tensor<?x?xi32>
-}
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @indexed_generic_op_generic_op_fusion
-//       CHECK: linalg.generic
-// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
-//      CHECK: ^{{[a-zA-Z0-9_]*}}
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
-//      CHECK:   %[[ARG0:.+]] = linalg.index 0 : index
-//      CHECK:   %[[ARG1:.+]] = linalg.index 1 : index
-//      CHECK:   %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
-//      CHECK:   %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
-//      CHECK:   %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND]] : i32
-//      CHECK:   %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
-//      CHECK:   %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG3]] : i32
-//      CHECK:   linalg.yield %[[VAL3]] : i32
-//   CHECK-NOT: linalg.generic
-
-// -----
-
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
   %c0 = constant 0 : index
@@ -506,63 +340,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32
 
 // -----
 
-// The indices of the first indexed_generic op are swapped after fusion.
-#map0 = affine_map<(d0, d1) -> (d1, d0)>
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
-  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
-  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
-  %3 = linalg.indexed_generic {
-    indexing_maps = [#map0, #map0],
-    iterator_types = ["parallel", "parallel"] }
-    ins(%arg0 : tensor<?x?xi32>)
-    outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32):       // no predecessors
-      %4 = index_cast %arg2 : index to i32
-      %5 = index_cast %arg3 : index to i32
-      %6 = addi %arg4, %4 : i32
-      %7 = subi %5, %6 : i32
-      linalg.yield %7 : i32
-    } -> tensor<?x?xi32>
-  %4= linalg.indexed_generic {
-    indexing_maps = [#map1, #map1],
-    iterator_types = ["parallel", "parallel"] }
-    ins(%3 : tensor<?x?xi32>)
-    outs(%2 : tensor<?x?xi32>) {
-    ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32):       // no predecessors
-      %5 = index_cast %arg2 : index to i32
-      %6 = index_cast %arg3 : index to i32
-      %7 = addi %arg4, %5 : i32
-      %8 = subi %7, %6 : i32
-      linalg.yield %8 : i32
-    } -> tensor<?x?xi32>
-  return %4 : tensor<?x?xi32>
-}
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @indexed_generic_op_fusion
-//       CHECK: linalg.generic
-// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
-//      CHECK: ^{{[a-zA-Z0-9_]*}}
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
-//      CHECK:   %[[ARG0:.+]] = linalg.index 0 : index
-//      CHECK:   %[[ARG1:.+]] = linalg.index 1 : index
-//      CHECK:   %[[ADD_OPERAND1:.+]] = index_cast %[[ARG1]] : index to i32
-//      CHECK:   %[[SUB_OPERAND1:.+]] = index_cast %[[ARG0]] : index to i32
-//      CHECK:   %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND1]] : i32
-//      CHECK:   %[[VAL2:.+]] = subi %[[SUB_OPERAND1]], %[[VAL1]] : i32
-//      CHECK:   %[[ADD_OPERAND2:.+]] = index_cast %[[ARG0]] : index to i32
-//      CHECK:   %[[SUB_OPERAND2:.+]] = index_cast %[[ARG1]] : index to i32
-//      CHECK:   %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
-//      CHECK:   %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
-//      CHECK:   linalg.yield %[[VAL4]] : i32
-//   CHECK-NOT: linalg.generic
-
-// -----
-
-// The indices of the first indexed_generic op are swapped after fusion.
+// The indices of the first generic op are swapped after fusion.
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>)
@@ -625,45 +403,52 @@ func @indexed_producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>)
 
 // -----
 
-func @scalar_indexed_generic_fusion
-  (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
-{
+#map1 = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d0, d1)>
+#map3 = affine_map<(d0, d1) -> (d1)>
+func @one_dim_indexed_producer_consumer_fusion(%arg0 : tensor<?xi32>,
+                                               %arg1 : tensor<?x?xi32>) -> tensor<?x?xi32> {
   %c0 = constant 0 : index
-  %cst = constant dense<1.000000e+00> : tensor<10xf32>
-  %0 = linalg.init_tensor [] : tensor<f32>
-  %1 = linalg.indexed_generic
-    {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
-     iterator_types = []}
-    ins(%arg1 : tensor<i32>) outs(%0 : tensor<f32>) {
-    ^bb0(%arg2: i32, %arg3: f32):  // no predecessors
-      %3 = index_cast %arg2 : i32 to index
-      %4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
-      linalg.yield %4 : f32
-    } -> tensor<f32>
-  %2 = linalg.init_tensor [10] : tensor<10xf32>
-  %3 = linalg.generic
-   {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
-                     affine_map<(d0) -> (d0)>],
-    iterator_types = ["parallel"]}
-    ins(%1, %cst : tensor<f32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) {
-    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
-      %4 = mulf %arg2, %arg3 : f32
-      linalg.yield %4 : f32
-    } -> tensor<10xf32>
-  return %3 : tensor<10xf32>
+  %c1 = constant 1 : index
+  %d0 = memref.dim %arg0, %c0 : tensor<?xi32>
+  %0 = linalg.init_tensor [%d0] : tensor<?xi32>
+  %1 = linalg.generic
+      {indexing_maps = [#map1, #map1],
+       iterator_types = ["parallel"]}
+      ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xi32>) {
+      ^bb0(%arg2 : i32, %arg3 : i32):
+        %2 = linalg.index 0 : index
+        %3 = index_cast %2 : index to i32
+        %4 = addi %arg2, %3 : i32
+        linalg.yield %4 : i32
+      } -> tensor<?xi32>
+  %2 = memref.dim %arg1, %c0 : tensor<?x?xi32>
+  %3 = memref.dim %arg1, %c1 : tensor<?x?xi32>
+  %4 = linalg.init_tensor [%2, %3] : tensor<?x?xi32>
+  %5 = linalg.generic
+      {indexing_maps = [#map2, #map3, #map2],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%arg1, %1 : tensor<?x?xi32>, tensor<?xi32>)
+      outs(%4 : tensor<?x?xi32>) {
+      ^bb0(%arg2 : i32, %arg3 : i32, %arg4: i32):
+        %6 = addi %arg2, %arg3 : i32
+        linalg.yield %6 : i32
+     } -> tensor<?x?xi32>
+  return %5 : tensor<?x?xi32>
 }
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
-//       CHECK: func @scalar_indexed_generic_fusion
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<i32>
-//       CHECK:   %[[T0:.+]] = linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
-//  CHECK-SAME:     iterator_types = ["parallel"]
-//  CHECK-SAME:     ins(%[[ARG1]] : tensor<i32>)
-//       CHECK:     tensor.extract %[[ARG0]]
-//       CHECK:     linalg.yield
-//       CHECK   return %[[T0]]
+//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-LABEL: func @one_dim_indexed_producer_consumer_fusion
+//       CHECK: linalg.generic
+// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]]
+//      CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]*]]: i32, %[[ARG1:[a-zA-Z0-9_]*]]: i32
+//      CHECK:   %[[IDX1:.+]] = linalg.index 1 : index
+//      CHECK:   %[[VAL1:.+]] = index_cast %[[IDX1]] : index to i32
+//      CHECK:   %[[VAL2:.+]] = addi %[[ARG1]], %[[VAL1]] : i32
+//      CHECK:   %[[VAL3:.+]] = addi %[[ARG0]], %[[VAL2]] : i32
+//      CHECK:   linalg.yield %[[VAL3]] : i32
+//   CHECK-NOT: linalg.generic
 
 // -----
 

diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ff9aeeb986f65..9ff534c8b6549 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -164,56 +164,6 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 
 // -----
 
-#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
-                                         %arg1 : tensor<?x?x?xi32>) ->
-                                         tensor<?x?x?xi32>
-{
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
-    tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
-  %1 = linalg.indexed_generic {
-     indexing_maps = [#map0, #map1, #map1],
-     iterator_types = ["parallel", "parallel", "parallel"]}
-       ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
-      outs(%0 : tensor<?x?x?xi32>) {
-    ^bb0(%arg3 : index, %arg4 : index, %arg5 : index, %arg6: i32, %arg7: i32, %s: i32):
-      %1 = muli %arg6, %arg7 : i32
-      %2 = index_cast %arg3 : index to i32
-      %3 = addi %1, %2 : i32
-      %4 = index_cast %arg4 : index to i32
-      %5 = addi %3, %4 : i32
-      %6 = index_cast %arg5 : index to i32
-      %7 = addi %5, %6 : i32
-      linalg.yield %7 : i32
-  } -> tensor<?x?x?xi32>
-  return %1 : tensor<?x?x?xi32>
-}
-
-// The generic op version of the test check for the op structure. Only
-// checking the op body here.
-//       CHECK: #[[MAP:.+]] =  affine_map<(d0, d1) -> (d0 + d1 * 4)>
-//       CHECK: func @indexed_generic_op_reshape_producer_fusion
-//       CHECK:   linalg.generic
-//       CHECK:   ^{{.*}}(
-//  CHECK-SAME:     %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32,
-//  CHECK-SAME:     %[[ARG8:[a-zA-Z0-9]+]]: i32)
-//       CHECK:     %[[ARG2:.+]] = linalg.index 0 : index
-//       CHECK:     %[[ARG3:.+]] = linalg.index 1 : index
-//       CHECK:     %[[ARG4:.+]] = linalg.index 2 : index
-//       CHECK:     %[[ARG5:.+]] = linalg.index 3 : index
-//       CHECK:     %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG3]], %[[ARG2]])
-//       CHECK:     %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
-//       CHECK:     %[[T5:.+]] = index_cast %[[T3]]
-//       CHECK:     %[[T6:.+]] = addi %[[T4]], %[[T5]]
-//       CHECK:     %[[T7:.+]] = index_cast %[[ARG4]]
-//       CHECK:     %[[T8:.+]] = addi %[[T6]], %[[T7]]
-//       CHECK:     %[[T9:.+]] = index_cast %[[ARG5]]
-//       CHECK:     %[[T10:.+]] = addi %[[T8]], %[[T9]]
-//       CHECK:     linalg.yield %[[T10]]
-
-// -----
-
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
 func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
@@ -266,50 +216,6 @@ func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
 
 // -----
 
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
-                                         %arg1 : tensor<?x?xi32>) ->
-                                         tensor<?x?x4x5xi32>
-{
-  %0 = linalg.indexed_generic {
-     indexing_maps = [#map0, #map0, #map0],
-     iterator_types = ["parallel", "parallel"]}
-       ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
-      outs(%arg0 : tensor<?x?xi32>) {
-    ^bb0(%arg3 : index, %arg4 : index, %arg5: i32, %arg6: i32, %s: i32):       // no predecessors
-      %1 = muli %arg5, %arg6 : i32
-      %2 = index_cast %arg3 : index to i32
-      %3 = addi %1, %2 : i32
-      %4 = index_cast %arg4 : index to i32
-      %5 = addi %3, %4 : i32
-      linalg.yield %5 : i32
-  } -> tensor<?x?xi32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
-    tensor<?x?xi32> into tensor<?x?x4x5xi32>
-  return %1 : tensor<?x?x4x5xi32>
-}
-// The generic op version of the test check for the op structure. Only
-// checking the op body here.
-//       CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)>
-//       CHECK: func @indexed_generic_op_reshape_consumer_fusion
-//       CHECK:   linalg.generic
-//       CHECK:   ^{{.*}}(
-//  CHECK-SAME:     %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32,
-//  CHECK-SAME:     %[[ARG8:[a-zA-Z0-9]+]]: i32)
-//       CHECK:     %[[ARG2:.+]] = linalg.index 0 : index
-//       CHECK:     %[[ARG3:.+]] = linalg.index 1 : index
-//       CHECK:     %[[ARG4:.+]] = linalg.index 2 : index
-//       CHECK:     %[[ARG5:.+]] = linalg.index 3 : index
-//       CHECK:     %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG5]], %[[ARG4]], %[[ARG3]])
-//       CHECK:     %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
-//       CHECK:     %[[T5:.+]] = index_cast %[[ARG2]]
-//       CHECK:     %[[T6:.+]] = addi %[[T4]], %[[T5]]
-//       CHECK:     %[[T7:.+]] = index_cast %[[T3]]
-//       CHECK:     %[[T8:.+]] = addi %[[T6]], %[[T7]]
-//       CHECK:     linalg.yield %[[T8]]
-
-// -----
-
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
                                          %arg1 : tensor<?x?xi32>) ->
@@ -356,69 +262,6 @@ func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
 
 // -----
 
-func @reshape_as_consumer_permutation
-  (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
-    -> tensor<2x3x4x5x6x7xi32> {
-  %shape = linalg.init_tensor [6, 4, 210] : tensor<6x4x210xi32>
-  %c = linalg.indexed_generic {
-         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
-                          affine_map<(d0, d1, d2) -> (d1, d2)>,
-                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
-         iterator_types = ["parallel", "parallel", "parallel"]}
-          ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
-          outs(%shape : tensor<6x4x210xi32>) {
-       ^bb0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i32, %arg4: i32, %s: i32):
-         %1 = addi %arg3, %arg4 : i32
-         %2 = index_cast %arg0 : index to i32
-         %3 = addi %1, %2 : i32
-         %4 = index_cast %arg1 : index to i32
-         %5 = addi %3, %4 : i32
-         %6 = index_cast %arg2 : index to i32
-         %7 = addi %5, %6 : i32
-         linalg.yield %7 : i32
-       } -> tensor<6x4x210xi32>
-  %d = linalg.tensor_reshape %c [[0, 1], [2], [3, 4, 5]]
-       : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
-  return %d : tensor<2x3x4x5x6x7xi32>
-}
-//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-//   CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-//   CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-//   CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-//   CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
-//       CHECK: func @reshape_as_consumer_permutation
-//  CHECK-SAME:   %[[ARG0:.+]]: tensor<210x6x4xi32>
-//  CHECK-SAME:   %[[ARG1:.+]]: tensor<210x4xi32>
-//   CHECK-DAG:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
-//  CHECK-SAME:     [0, 1, 2], [3, 4], [5]
-//   CHECK-DAG:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
-//  CHECK-SAME:     [0, 1, 2], [3]
-//   CHECK-DAG:   %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
-//       CHECK:   %[[T4:.+]] = linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
-//  CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
-//  CHECK-SAME:     outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
-//       CHECK:   ^{{.+}}(
-//  CHECK-SAME:     %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
-//  CHECK-SAME:     %[[ARG10:[a-zA-Z0-9]+]]: i32)
-//       CHECK:     %[[ARG2:.+]] = linalg.index 0 : index
-//       CHECK:     %[[ARG3:.+]] = linalg.index 1 : index
-//       CHECK:     %[[ARG4:.+]] = linalg.index 2 : index
-//       CHECK:     %[[ARG5:.+]] = linalg.index 3 : index
-//       CHECK:     %[[ARG6:.+]] = linalg.index 4 : index
-//       CHECK:     %[[ARG7:.+]] = linalg.index 5 : index
-//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP8]](%[[ARG3]], %[[ARG2]])
-//   CHECK-DAG:       %[[T6:.+]] = affine.apply #[[MAP9]](%[[ARG6]], %[[ARG5]], %[[ARG4]])
-//   CHECK-DAG:       %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
-//       CHECK:       %[[T8:.+]] = index_cast %[[T5]]
-//       CHECK:       %[[T9:.+]] = addi %[[T7]], %[[T8]]
-//       CHECK:       %[[T10:.+]] = index_cast %[[T6]]
-//       CHECK:       %[[T11:.+]] = addi %[[T9]], %[[T10]]
-//       CHECK:       %[[T12:.+]] = index_cast %[[ARG7]]
-//       CHECK:       %[[T13:.+]] = addi %[[T11]], %[[T12]]
-
-// -----
-
 func @reshape_as_consumer_permutation
   (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
     -> tensor<2x3x4x5x6x7xi32> {
@@ -487,59 +330,6 @@ func @reshape_as_consumer_permutation
 
 // -----
 
-func @reshape_as_producer_projected_permutation(
-    %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
-{
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
-    : tensor<33x8x?xi32> into tensor<264x?xi32>
-  %1 = linalg.indexed_generic
-    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
-                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
-     iterator_types = ["parallel", "parallel", "parallel"]}
-     ins(%0 : tensor<264x?xi32>)
-    outs(%shape : tensor<264x?x4xi32>) {
-  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: i32, %s: i32):  // no predecessors
-    %2 = index_cast %arg1 : index to i32
-    %3 = addi %arg4, %2 : i32
-    %4 = index_cast %arg2 : index to i32
-    %5 = addi %3, %4 : i32
-    %6 = index_cast %arg3 : index to i32
-    %7 = addi %5, %6 : i32
-    linalg.yield %7 : i32
-  } -> tensor<264x?x4xi32>
-  return %1 : tensor<264x?x4xi32>
-}
-
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
-//       CHECK: @reshape_as_producer_projected_permutation
-//  CHECK-SAME:   %[[ARG0:.+]]: tensor<33x8x?xi32>
-//       CHECK:   %[[RES:.+]] = linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
-//  CHECK-SAME:     ins(%[[ARG0]] : tensor<33x8x?xi32>)
-//       CHECK:   ^{{.+}}(
-//  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9]+]]: i32,
-//  CHECK-SAME:     %[[ARG7:[a-zA-Z0-9]+]]: i32)
-//       CHECK:       %[[ARG1:.+]] = linalg.index 0 : index
-//       CHECK:       %[[ARG2:.+]] = linalg.index 1 : index
-//       CHECK:       %[[ARG3:.+]] = linalg.index 2 : index
-//       CHECK:       %[[ARG4:.+]] = linalg.index 3 : index
-//       CHECK:       %[[T0:.+]] = affine.apply #[[MAP2]](%[[ARG2]], %[[ARG1]])
-//       CHECK:       %[[T1:.+]] = index_cast %[[T0]] : index to i32
-//       CHECK:       %[[T2:.+]] = addi %[[ARG5]], %[[T1]] : i32
-//       CHECK:       %[[T3:.+]] = index_cast %[[ARG3]] : index to i32
-//       CHECK:       %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32
-//       CHECK:       %[[T5:.+]] = index_cast %[[ARG4]] : index to i32
-//       CHECK:       %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
-//       CHECK:       linalg.yield %[[T6]] : i32
-//       CHECK:    %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
-//  CHECK-SAME:      [0, 1], [2], [3]
-//  CHECK-SAME:    : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
-//       CHECK:  return %[[RES2]] : tensor<264x?x4xi32>
-
-// -----
-
 func @reshape_as_producer_projected_permutation(
     %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
 {

diff  --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
index 623350de97a8f..15fc2b5de0ae2 100644
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -1,75 +1,18 @@
 // RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
-    %arg1 : tensor<?x?x4x?xf32>) -> tensor<?x?x4x?xf32> {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
-    tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
-  %1 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-    ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>)
-    outs(%0 : tensor<?x?x4x?xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):       // no predecessors
-      %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
-  } -> tensor<?x?x4x?xf32>
-  return %1 : tensor<?x?x4x?xf32>
-}
-//   CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-//   CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-//       CHECK: func @generic_op_reshape_producer_fusion
-//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
-//       CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
-//  CHECK-SAME:   [0], [1, 2], [3]
-//       CHECK: linalg.generic
-//  CHECK-SAME:   indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP4]]]
-//  CHECK-SAME:   ins(%[[ARG0]], %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>)
-//  CHECK-SAME:   outs(%[[T0]] : tensor<?x?x4x?xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
-    %arg1 : tensor<?x?x4x5xf32>) -> tensor<?x?xf32> {
-  %0 = linalg.generic {
-    indexing_maps = [#map0, #map0, #map0],
-    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-    ins(%arg0, %arg1 : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
-    outs(%arg0 : tensor<?x?x4x5xf32>){
-    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):       // no predecessors
-      %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
-  } -> tensor<?x?x4x5xf32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
-    tensor<?x?x4x5xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-//   CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-//   CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-//       CHECK: func @generic_op_reshape_consumer_fusion
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xf32>
-//       CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
-//  CHECK-SAME:     [0], [1, 2, 3]
-//       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
-//  CHECK-SAME:     outs(%[[T0]] : tensor<?x?xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
+func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
   -> tensor<?x?x4x?xi32> {
   %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
     tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
-  %1 = linalg.indexed_generic {
+  %1 = linalg.generic {
     indexing_maps = [#map0, #map0],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
     ins(%0 : tensor<?x?x4x?xi32>)
     outs(%0 : tensor<?x?x4x?xi32>) {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32, %arg7 : i32):       // no predecessors
-    %2 = index_cast %arg2 : index to i32
+  ^bb0(%arg6: i32, %arg7 : i32):       // no predecessors
+    %idx = linalg.index 0 : index
+    %2 = index_cast %idx : index to i32
     %3 = addi %arg6, %2 : i32
     linalg.yield %3 : i32
   } -> tensor<?x?x4x?xi32>
@@ -77,26 +20,29 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
 }
 //   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
 //   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-//       CHECK: func @indexed_generic_op_reshape_producer_fusion
+//       CHECK: func @generic_op_reshape_producer_fusion
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xi32>
 //       CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 //  CHECK-SAME:     [0], [1, 2], [3]
-//       CHECK:   linalg.indexed_generic
+//       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP3]], #[[MAP4]]]
 //  CHECK-SAME:     ins(%[[ARG0]] : tensor<?x?x?xi32>)
 //  CHECK-SAME:     outs(%[[T0]] : tensor<?x?x4x?xi32>)
+//       CHECK:   %[[IDX:.+]] = linalg.index 0 : index
+//  CHECK-NEXT:   %[[IDX_CASTED:.+]] = index_cast %[[IDX]] : index to i32
 
 // -----
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
+func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
   -> tensor<?x?xi32> {
-  %0 = linalg.indexed_generic {
+  %0 = linalg.generic {
     indexing_maps = [#map0, #map0],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
     ins(%arg0 : tensor<?x?x4x5xi32>) outs(%arg0 : tensor<?x?x4x5xi32>) {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32, %arg7: i32):       // no predecessors
-    %2 = index_cast %arg2 : index to i32
+  ^bb0(%arg6: i32, %arg7: i32):       // no predecessors
+    %idx = linalg.index 0 : index
+    %2 = index_cast %idx : index to i32
     %3 = addi %arg6, %2 : i32
     linalg.yield %3 : i32
   } -> tensor<?x?x4x5xi32>
@@ -106,13 +52,15 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
 }
 //   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-//       CHECK: func @indexed_generic_op_reshape_consumer_fusion
+//       CHECK: func @generic_op_reshape_consumer_fusion
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
 //       CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 //  CHECK-SAME:     [0], [1, 2, 3]
-//       CHECK:   linalg.indexed_generic
+//       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
 //  CHECK-SAME:     outs(%[[T0]] : tensor<?x?xi32>)
+//       CHECK:   %[[IDX:.+]] = linalg.index 0 : index
+//  CHECK-NEXT:   %[[IDX_CASTED:.+]] = index_cast %[[IDX]] : index to i32
 //   CHECK-NOT:   linalg.tensor_reshape
 
 // -----


        


More information about the Mlir-commits mailing list