[Mlir-commits] [mlir] 485c21b - [mlir] Split linalg reshape ops into expand/collapse.
    Alexander Belyaev 
    llvmlistbot at llvm.org
       
    Thu Jun  3 02:40:36 PDT 2021
    
    
  
Author: Alexander Belyaev
Date: 2021-06-03T11:40:22+02:00
New Revision: 485c21be8ac300209bac0db03bc8f476aa8b9764
URL: https://github.com/llvm/llvm-project/commit/485c21be8ac300209bac0db03bc8f476aa8b9764
DIFF: https://github.com/llvm/llvm-project/commit/485c21be8ac300209bac0db03bc8f476aa8b9764.diff
LOG: [mlir] Split linalg reshape ops into expand/collapse.
Differential Revision: https://reviews.llvm.org/D103548
Added: 
    
Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Linalg/bufferize.mlir
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/detensorize_0d.mlir
    mlir/test/Dialect/Linalg/detensorize_if.mlir
    mlir/test/Dialect/Linalg/detensorize_trivial.mlir
    mlir/test/Dialect/Linalg/detensorize_while.mlir
    mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
    mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/llvm.mlir
    mlir/test/Dialect/Linalg/reshape_fusion.mlir
    mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index abba73d74f874..d74e407640b0b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -330,7 +330,12 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
     // be either a contracting or expanding reshape.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      build($_builder, $_state, resultType, src, attrs);
+      $_state.addAttribute("reassociation",
+                          getReassociationIndicesAttribute($_builder, reassociation));
+    }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
@@ -355,21 +360,33 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
       return reassociationIndices;
     };
   }];
+
+  let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
 }
 
 def IndexListArrayAttr :
   TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
 
-def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
+class Linalg_ReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<mnemonic,
     [DeclareOpInterfaceMethods<ViewLikeOpInterface>]>,
     Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)> {
-  let summary = "linalg.reshape produces a new view into the operand view";
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
+    MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
+  }];
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+  let printer = [{ return ::print(p, *this); }];
+}
+
+def Linalg_ExpandShapeOp : Linalg_ReshapeOp<"expand_shape"> {
+  let summary = "operation to produce a memref with a higher rank.";
   let description = [{
-    The `linalg.reshape` op produces a new view whose sizes are a reassociation
-    of the original `view`. Depending on whether or not the reassociated
-    MemRefType is contiguous, the resulting memref may require explicit alloc
-    and copies.
+    The `linalg.expand_shape` op produces a new view with a higher rank whose
+    sizes are a reassociation of the original `view`. Depending on whether or
+    not the reassociated MemRefType is contiguous, the resulting memref may
+    require explicit alloc and copies.
 
     A reassociation is defined as a continuous grouping of dimensions and is
     represented with an array of I64ArrayAttr attribute.
@@ -381,85 +398,67 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
     All other cases are undefined behavior and a reshape op may not lower to
     LLVM if it cannot be proven statically that it does not require alloc+copy.
 
-    A reshape may either collapse or expand dimensions, depending on the
-    relationship between source and target memref ranks. The verification rule
-    is that the reassociation maps are applied to the memref with the larger
-    rank to obtain the memref with the smaller rank. In the case of a dimension
-    expansion, the reassociation maps can be interpreted as inverse maps.
-
-    The result memref type of a reshape when dimensions are collapsed
-    (operand memref type when dimensions are expanded) can be
-    zero-ranked if the operand memref type (or the result memref type
-    when dimensions are expanded) is statically shaped with all
-    dimensions being unit extent. In such cases the reassociation map
-    is empty.
+    The operand memref type of a reshape can be zero-ranked if the result
+    memref type is statically shaped with all dimensions being unit extent. In
+    such cases the reassociation map is empty.
 
-    Examples:
+    The verification rule is that the reassociation maps are applied to the
+    result memref with the larger rank to obtain the operand memref with the
+    smaller rank.
 
-    ```mlir
-    // Dimension collapse (i, j) -> i' and k -> k'
-    %1 = linalg.reshape %0 [[0, 1], [2]] :
-      memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
-    ```
+    Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
-    %1 = linalg.reshape %0 [[0, 1], [2]] :
+    %1 = linalg.expand_shape %0 [[0, 1], [2]] :
       memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
     ```
   }];
-  let extraClassDeclaration = commonExtraClassDeclaration # [{
-    MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
-    MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
-  }];
-  let hasFolder = 1;
-  let hasCanonicalizer = 1;
-  let printer = [{ return ::print(p, *this); }];
-  let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
 }
 
-def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
-    "tensor_reshape",
-    [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-      ["reifyReturnTypeShapesPerResultDim"]>]>,
-    Arguments<(ins AnyTensor:$src,
-                   IndexListArrayAttr:$reassociation)>,
-    Results<(outs AnyTensor:$result)> {
-  let summary = "linalg.tensor_reshape produces a new reshaped tensor.";
+def Linalg_CollapseShapeOp : Linalg_ReshapeOp<"collapse_shape"> {
+  let summary = "operation to produce a memref with a smaller rank.";
   let description = [{
-    The `linalg.reshape` op produces a new tensor whose sizes are a
-    reassociation of the original `src`.
+    The `linalg.collapse_shape` op produces a new view with a smaller rank
+    whose sizes are a reassociation of the original `view`. Depending on
+    whether or not the reassociated MemRefType is contiguous, the resulting
+    memref may require explicit alloc and copies.
 
     A reassociation is defined as a continuous grouping of dimensions and is
     represented with an array of I64ArrayAttr attribute.
 
-    A reshape may either collapse or expand dimensions, depending on the
-    relationship between source and target tensor ranks. The verification rule
-    is that the reassociation maps are applied to the tensor with the larger
-    rank to obtain the tensor with the smaller rank. In the case of a dimension
-    expansion, the reassociation maps can be interpreted as inverse maps.
+    For now, it is assumed that either:
+      1. a reassociation produces and consumes contiguous MemRefType or,
+      2. the reshape op will be folded into its consumers (by changing the shape
+         of the computations).
+    All other cases are undefined behavior and a reshape op may not lower to
+    LLVM if it cannot be proven statically that it does not require alloc+copy.
+
+    The result memref type of a reshape can be zero-ranked if the operand
+    memref type is statically shaped with all dimensions being unit extent. In
+    such cases the reassociation map is empty.
 
-    The result tensor type of a reshape when dimensions are collapsed
-    (operand tensor type when dimensions are expanded) can be
-    zero-ranked if the operand tensor type (or the result tensor type
-    when dimensions are expanded) is statically shaped with all
-    dimensions being unit extent. In such cases the reassociation map
-    is empty.
+    The verification rule is that the reassociation maps are applied to the
+    operand memref with the larger rank to obtain the result memref with the
+    smaller rank.
 
     Examples:
 
     ```mlir
     // Dimension collapse (i, j) -> i' and k -> k'
-    %b = linalg.tensor_reshape %a [[0, 1], [2]]
-        : tensor<?x?x?xf32> into tensor<?x?xf32>
-    ```
-
-    ```mlir
-    // Dimension expansion i -> (i', j') and (k) -> (k')
-    %b = linalg.tensor_reshape %a [[0, 1], [2]]
-        : tensor<?x?xf32> into tensor<?x?x?xf32>
+    %1 = linalg.collapse_shape %0 [[0, 1], [2]] :
+      memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
     ```
   }];
+}
+
+class Linalg_TensorReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<
+    mnemonic,
+    [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+      ["reifyReturnTypeShapesPerResultDim"]>]>,
+    Arguments<(ins AnyTensor:$src,
+                   IndexListArrayAttr:$reassociation)>,
+    Results<(outs AnyTensor:$result)> {
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     RankedTensorType getSrcType() {
       return src().getType().cast<RankedTensorType>();
@@ -474,6 +473,60 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
   let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
 }
 
+def Linalg_TensorExpandShapeOp : Linalg_TensorReshapeOp<"tensor_expand_shape"> {
+  let summary = "operation to produce a tensor with a higher rank";
+  let description = [{
+    The `linalg.tensor_expand_shape` op produces a new tensor with a higher
+    rank whose sizes are a reassociation of the original `src`.
+
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of I64ArrayAttr attribute.
+
+    The verification rule is that the reassociation maps are applied to the
+    result tensor with the higher rank to obtain the operand tensor with the
+    smaller rank.
+
+    The operand tensor type of a reshape can be zero-ranked if the result
+    tensor type is statically shaped with all dimensions being unit extent. In
+    such cases the reassociation map is empty.
+
+    Example:
+
+    ```mlir
+    // Dimension expansion i -> (i', j') and (k) -> (k')
+    %b = linalg.tensor_expand_shape %a [[0, 1], [2]]
+        : tensor<?x?xf32> into tensor<?x?x?xf32>
+    ```
+  }];
+}
+
+def Linalg_TensorCollapseShapeOp : Linalg_TensorReshapeOp<"tensor_collapse_shape"> {
+  let summary = "operation to produce a tensor with a smaller rank";
+  let description = [{
+    The `linalg.tensor_collapse_shape` op produces a new tensor with a smaller
+    rank whose sizes are a reassociation of the original `src`.
+
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of I64ArrayAttr attribute.
+
+    The verification rule is that the reassociation maps are applied to the
+    operand tensor with the higher rank to obtain the result tensor with the
+    smaller rank.
+
+    The result tensor type of a reshape can be zero-ranked if the operand
+    tensor type is statically shaped with all dimensions being unit extent. In
+    such cases the reassociation map is empty.
+
+    Example:
+
+    ```mlir
+    // Dimension collapse (i, j) -> i' and k -> k'
+    %b = linalg.tensor_collapse_shape %a [[0, 1], [2]]
+        : tensor<?x?x?xf32> into tensor<?x?xf32>
+    ```
+  }];
+}
+
 def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
     Arguments<(ins Variadic<AnyType>:$values)> {
   let summary = "Linalg yield operation";
diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 435f5b9ec7bca..66e9e332f93af 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -95,9 +95,11 @@ class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
 // ReshapeOp creates a new view descriptor of the proper rank.
 // For now, the only conversion supported is for target MemRef with static sizes
 // and strides.
+template <typename ReshapeOp>
 class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
 public:
   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
+  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
 
   LogicalResult
   matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
@@ -118,8 +120,9 @@ class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
     ReshapeOpAdaptor adaptor(operands);
     MemRefDescriptor baseDesc(adaptor.src());
     Location loc = reshapeOp->getLoc();
-    auto desc = MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
-                                        typeConverter->convertType(dstType));
+    auto desc =
+        MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
+                                this->typeConverter->convertType(dstType));
     desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc));
     desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc));
     desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc));
@@ -149,7 +152,8 @@ class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
 /// Populate the given list with patterns that convert from Linalg to LLVM.
 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns) {
-  patterns.add<RangeOpConversion, ReshapeOpConversion, YieldOpConversion>(
+  patterns.add<RangeOpConversion, ReshapeOpConversion<ExpandShapeOp>,
+               ReshapeOpConversion<CollapseShapeOp>, YieldOpConversion>(
       converter);
 
   // Populate the type conversions for the linalg types.
diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index b5c1894e5dcd6..28dd7bb860c78 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -191,7 +191,8 @@ void ConvertLinalgToStandardPass::runOnOperation() {
   target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
                          StandardOpsDialect>();
   target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
-  target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
+  target.addLegalOp<linalg::ExpandShapeOp, linalg::CollapseShapeOp,
+                    linalg::RangeOp>();
   RewritePatternSet patterns(&getContext());
   populateLinalgToStandardConversionPatterns(patterns);
   if (failed(applyFullConversion(module, target, std::move(patterns))))
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 113c34304e761..808bb8d5a5d09 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1188,16 +1188,20 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
           getIdentityExprs(resultTy.getShape().size())};
 
       auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
-      Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
+      Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
           loc, collapsedTy, args[0], collapsingMap);
-      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
           reshape, resultTy, collapsedOp, expandingMap);
 
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
-        reshape, resultTy, args[0], reassociationMap);
+    if (resultTy.getRank() < args[0].getType().cast<ShapedType>().getRank())
+      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+          reshape, resultTy, args[0], reassociationMap);
+    else
+      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+          reshape, resultTy, args[0], reassociationMap);
 
     return success();
   }
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 034bf188d175e..04c6e4b9d4029 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -882,13 +882,14 @@ struct FoldInitTensorWithSubTensorOp : public OpRewritePattern<SubTensorOp> {
   }
 };
 
+template <typename TensorReshapeOp>
 struct FoldInitTensorWithTensorReshapeOp
     : public OpRewritePattern<TensorReshapeOp> {
   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
+    if (!reshapeOp.src().template getDefiningOp<InitTensorOp>())
       return failure();
     Location loc = reshapeOp.getLoc();
     SmallVector<SmallVector<Value>, 4> resultShapes;
@@ -912,7 +913,9 @@ struct FoldInitTensorWithTensorReshapeOp
 
 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
-  results.add<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
+  results.add<FoldInitTensorWithSubTensorOp,
+              FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
+              FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
               ReplaceStaticShapeDims>(context);
 }
 
@@ -1206,12 +1209,24 @@ static void print(OpAsmPrinter &p, ReshapeLikeOp op) {
   p << ": " << op.src().getType() << " into " << op.getType();
 }
 
-static void print(OpAsmPrinter &p, linalg::ReshapeOp op) {
-  print<linalg::ReshapeOp>(p, op);
+static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) {
+  print<linalg::ExpandShapeOp>(p, op);
 }
 
-static void print(OpAsmPrinter &p, linalg::TensorReshapeOp op) {
-  print<linalg::TensorReshapeOp>(p, op);
+static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) {
+  print<linalg::CollapseShapeOp>(p, op);
+}
+
+static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
+  print<linalg::TensorExpandShapeOp>(p, op);
+}
+
+static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) {
+  print<linalg::TensorCollapseShapeOp>(p, op);
+}
+
+static constexpr StringRef getReassociationAttrName() {
+  return "reassociation";
 }
 
 static ParseResult parseReshapeLikeOp(OpAsmParser &parser,
@@ -1253,7 +1268,7 @@ static ParseResult parseReshapeLikeOp(OpAsmParser &parser,
     break;
   }
 
-  result.addAttribute(ReshapeOp::getReassociationAttrName(),
+  result.addAttribute(getReassociationAttrName(),
                       b.getArrayAttr(reassociation));
 
   // Parse optional attributes.
@@ -1334,36 +1349,10 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
     ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
     ShapedType intermediateType = reshapeOp.getSrcType();
     ShapedType resultType = reshapeOp.getResultType();
-
-    auto areReshapeOpsFoldable = [](ShapedType largerType,
-                                    ShapedType intermediateType,
-                                    ShapedType smallerType) -> bool {
-      return largerType.getRank() > intermediateType.getRank() &&
-             intermediateType.getRank() > smallerType.getRank();
-    };
     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
-        llvm::None;
-    // Check if producer and consumer are both expanding dims or both collapsing
-    // dims. In this case, try to compose the affine maps. This works for
-    // dynamic shapes too.
-    if (areReshapeOpsFoldable(resultType, intermediateType,
-                              srcReshapeSrcType) ||
-        areReshapeOpsFoldable(srcReshapeSrcType, intermediateType,
-                              resultType)) {
-      reassociationIndices = collapseReassociationIndices(
-          srcReshapeOp.getReassociationMaps(), reshapeOp.getReassociationMaps(),
-          rewriter.getContext());
-    }
-    if (!reassociationIndices) {
-      // If the source reshape can be collapsed/expanded into the target reshape
-      // they can still be folded. This can only be reasoned about statically
-      // for cases where
-      // - either all shapes are static, or
-      // - The number of dynamic dimensions matches in the source of source and
-      //   result with all other dimensions being 1.
-      reassociationIndices =
-          getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
-    }
+        collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
+                                     reshapeOp.getReassociationMaps(),
+                                     rewriter.getContext());
     if (!reassociationIndices)
       return failure();
     rewriter.replaceOpWithNewOp<ReshapeOpTy>(
@@ -1371,15 +1360,55 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
     return success();
   }
 };
+
+/// Pattern to fold a reshape op whose producer is a reshape op of the inverse
+/// kind, i.e. one op collapses dimensions while the other expands them.
+template <typename ReshapeOpTy, typename InverseReshapeOpTy>
+struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcReshapeOp =
+        reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
+    if (!srcReshapeOp)
+      return failure();
+
+    ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
+    ShapedType intermediateType = reshapeOp.getSrcType();
+    ShapedType resultType = reshapeOp.getResultType();
+
+    // If the source reshape can be collapsed/expanded into the target reshape
+    // they can still be folded. This can only be reasoned about statically
+    // for cases where
+    // - either all shapes are static, or
+    // - The number of dynamic dimensions matches in the source of source and
+    //   result with all other dimensions being 1.
+    Optional<SmallVector<ReassociationIndices>> reassociationIndices =
+        getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
+    if (!reassociationIndices)
+      return failure();
+    bool originalOpExpands =
+        intermediateType.getRank() > srcReshapeSrcType.getRank();
+    bool resultingOpExpands =
+        resultType.getRank() > srcReshapeSrcType.getRank();
+    if (!(resultingOpExpands ^ originalOpExpands))
+      rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+    else
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
+    return success();
+  }
+};
 } // namespace
 
-template <typename ReshapeOpTy>
+template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
   // Fold producer-consumer reshape ops that where the operand type of the
   // producer is same as the return type of the consumer.
-  ReshapeOpTy reshapeSrcOp =
-      reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
+  auto reshapeSrcOp =
+      reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
   if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.src();
   // Reshape of a constant can be replaced with a new constant.
@@ -1564,20 +1593,38 @@ convertReassociationIndicesToExprs(
   return reassociationMaps;
 }
 
-SmallVector<AffineMap, 4> ReshapeOp::getReassociationMaps() {
+SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
+  return getSymbolLessAffineMaps(getReassociationExprs());
+}
+SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
+  OpBuilder b(this->getContext());
+  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+}
+SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
+  return getSymbolLessAffineMaps(getReassociationExprs());
+}
+SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
+  OpBuilder b(this->getContext());
+  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+}
+
+SmallVector<AffineMap, 4> TensorCollapseShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
-SmallVector<ReassociationExprs, 4> ReshapeOp::getReassociationExprs() {
+SmallVector<ReassociationExprs, 4>
+TensorCollapseShapeOp::getReassociationExprs() {
   OpBuilder b(this->getContext());
   return convertReassociationIndicesToExprs(b, getReassociationIndices());
 }
-SmallVector<AffineMap, 4> TensorReshapeOp::getReassociationMaps() {
+SmallVector<AffineMap, 4> TensorExpandShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
-SmallVector<ReassociationExprs, 4> TensorReshapeOp::getReassociationExprs() {
+SmallVector<ReassociationExprs, 4>
+TensorExpandShapeOp::getReassociationExprs() {
   OpBuilder b(this->getContext());
   return convertReassociationIndicesToExprs(b, getReassociationIndices());
 }
+
 /// For reshape op compute the shape at dimension `dimIndex` of the output in
 /// terms of shape of the `src`, when the reshape op is a collapsing
 /// operation. It is the product of the shape of the collapsed dimensions of the
@@ -1708,7 +1755,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
   return b.getArrayAttr(reassociationAttr);
 }
 
-void mlir::linalg::ReshapeOp::build(
+void mlir::linalg::ExpandShapeOp::build(
     OpBuilder &b, OperationState &result, Value src,
     ArrayRef<ReassociationIndices> reassociation,
     ArrayRef<NamedAttribute> attrs) {
@@ -1717,20 +1764,26 @@ void mlir::linalg::ReshapeOp::build(
       memRefType, getSymbolLessAffineMaps(
                       convertReassociationIndicesToExprs(b, reassociation)));
   build(b, result, resultType, src, attrs);
-  result.addAttribute(ReshapeOp::getReassociationAttrName(),
+  result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
 }
 
-void mlir::linalg::ReshapeOp::build(
-    OpBuilder &b, OperationState &result, Type resultType, Value src,
+Value mlir::linalg::ExpandShapeOp::getViewSource() { return src(); }
+
+void mlir::linalg::CollapseShapeOp::build(
+    OpBuilder &b, OperationState &result, Value src,
     ArrayRef<ReassociationIndices> reassociation,
     ArrayRef<NamedAttribute> attrs) {
+  auto memRefType = src.getType().cast<MemRefType>();
+  auto resultType = computeReshapeCollapsedType(
+      memRefType, getSymbolLessAffineMaps(
+                      convertReassociationIndicesToExprs(b, reassociation)));
   build(b, result, resultType, src, attrs);
-  result.addAttribute(ReshapeOp::getReassociationAttrName(),
+  result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
 }
 
-Value mlir::linalg::ReshapeOp::getViewSource() { return src(); }
+Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); }
 
 /// Verify that shapes of the reshaped types using following rules
 /// 1) if a dimension in the collapsed type is static, then the corresponding
@@ -1785,18 +1838,17 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
 
 // Common verifier for reshape-like types. Fills `expandedType` and
 // `collapsedType` with the proper `src` or `result` type.
-template <typename Op, typename T>
-static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType,
-                                            T &collapsedType) {
-  expandedType = op.getSrcType();
-  collapsedType = op.getResultType();
+template <typename Op, typename T,
+          bool isExpansion = std::is_same<Op, TensorExpandShapeOp>::value ||
+                             std::is_same<Op, ExpandShapeOp>::value>
+static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
+                                            T collapsedType) {
   unsigned expandedRank = expandedType.getRank();
   unsigned collapsedRank = collapsedType.getRank();
-  bool isCollapse = expandedRank > collapsedRank;
-  if (!isCollapse) {
-    std::swap(expandedRank, collapsedRank);
-    std::swap(expandedType, collapsedType);
-  }
+  if (expandedRank < collapsedRank)
+    return op.emitOpError("expected the type ")
+           << expandedType
+           << " to have higher rank than the type = " << collapsedType;
   if (expandedRank == 0)
     return op.emitOpError("expected non-zero memref ranks");
   if (expandedRank == collapsedRank)
@@ -1825,11 +1877,13 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType,
   if (!isReassociationValid(maps, &invalidIdx))
     return op.emitOpError("expected reassociation map #")
            << invalidIdx << " to be valid and contiguous";
-  return verifyReshapeLikeShapes(op, collapsedType, expandedType, !isCollapse);
+  return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
 }
 
-static LogicalResult verify(ReshapeOp op) {
-  MemRefType expandedType, collapsedType;
+template <typename TensorReshapeOp>
+static LogicalResult verifyReshapeOp(TensorReshapeOp op,
+                                     MemRefType expandedType,
+                                     MemRefType collapsedType) {
   if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
     return failure();
   auto maps = op.getReassociationMaps();
@@ -1840,9 +1894,24 @@ static LogicalResult verify(ReshapeOp op) {
   return success();
 }
 
-void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                            MLIRContext *context) {
-  results.add<CollapseReshapeOps<ReshapeOp>>(context);
+static LogicalResult verify(ExpandShapeOp op) {
+  return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
+}
+
+void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<CollapseReshapeOps<ExpandShapeOp>,
+              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
+}
+
+static LogicalResult verify(CollapseShapeOp op) {
+  return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
+}
+
+void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<CollapseReshapeOps<CollapseShapeOp>,
+              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1877,7 +1946,7 @@ computeTensorReshapeCollapsedType(RankedTensorType type,
   return RankedTensorType::get(newShape, type.getElementType());
 }
 
-void mlir::linalg::TensorReshapeOp::build(
+void mlir::linalg::TensorCollapseShapeOp::build(
     OpBuilder &b, OperationState &result, Value src,
     ArrayRef<ReassociationIndices> reassociation,
     ArrayRef<NamedAttribute> attrs) {
@@ -1886,21 +1955,27 @@ void mlir::linalg::TensorReshapeOp::build(
       getSymbolLessAffineMaps(
           convertReassociationIndicesToExprs(b, reassociation)));
   build(b, result, resultType, src, attrs);
-  result.addAttribute(ReshapeOp::getReassociationAttrName(),
+  result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
 }
 
-void mlir::linalg::TensorReshapeOp::build(
-    OpBuilder &b, OperationState &result, Type resultType, Value src,
+void mlir::linalg::TensorExpandShapeOp::build(
+    OpBuilder &b, OperationState &result, Value src,
     ArrayRef<ReassociationIndices> reassociation,
     ArrayRef<NamedAttribute> attrs) {
+  auto resultType = computeTensorReshapeCollapsedType(
+      src.getType().cast<RankedTensorType>(),
+      getSymbolLessAffineMaps(
+          convertReassociationIndicesToExprs(b, reassociation)));
   build(b, result, resultType, src, attrs);
-  result.addAttribute(ReshapeOp::getReassociationAttrName(),
+  result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
 }
 
-static LogicalResult verify(TensorReshapeOp op) {
-  RankedTensorType expandedType, collapsedType;
+template <typename TensorReshapeOp>
+static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
+                                           RankedTensorType expandedType,
+                                           RankedTensorType collapsedType) {
   if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
     return failure();
 
@@ -1913,9 +1988,18 @@ static LogicalResult verify(TensorReshapeOp op) {
   return success();
 }
 
+static LogicalResult verify(TensorExpandShapeOp op) {
+  return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType());
+}
+
+static LogicalResult verify(TensorCollapseShapeOp op) {
+  return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType());
+}
+
 namespace {
 /// Reshape of a splat constant can be replaced with a constant of the result
 /// type.
+template <typename TensorReshapeOp>
 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
@@ -1936,11 +2020,12 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
 ///
 /// For such op chains, we can create new linalg.fill ops with the result
 /// type of the linalg.tensor_reshape op.
+template <typename TensorReshapeOp>
 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    auto oldFill = reshapeOp.src().getDefiningOp<FillOp>();
+    auto oldFill = reshapeOp.src().template getDefiningOp<FillOp>();
     if (!oldFill)
       return failure();
 
@@ -1955,14 +2040,38 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
 };
 } // namespace
 
-void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                  MLIRContext *context) {
-  results.add<CollapseReshapeOps<TensorReshapeOp>, FoldFillWithTensorReshape,
-              FoldInitTensorWithTensorReshapeOp, FoldReshapeWithConstant>(
-      context);
+void TensorExpandShapeOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results
+      .add<CollapseReshapeOps<TensorExpandShapeOp>,
+           CollapseMixedReshapeOps<TensorExpandShapeOp, TensorCollapseShapeOp>,
+           FoldFillWithTensorReshape<TensorExpandShapeOp>,
+           FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
+           FoldReshapeWithConstant<TensorExpandShapeOp>>(context);
+}
+
+void TensorCollapseShapeOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results
+      .add<CollapseReshapeOps<TensorCollapseShapeOp>,
+           CollapseMixedReshapeOps<TensorCollapseShapeOp, TensorExpandShapeOp>,
+           FoldFillWithTensorReshape<TensorCollapseShapeOp>,
+           FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
+           FoldReshapeWithConstant<TensorCollapseShapeOp>>(context);
+}
+
+LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+  auto resultShape =
+      getAsValues(b, getLoc(),
+                  getReshapeOutputShapeFromInputShape(
+                      b, getLoc(), src(), getResultType().getShape(),
+                      getReassociationMaps()));
+  reifiedReturnShapes.emplace_back(std::move(resultShape));
+  return success();
 }
 
-LogicalResult TensorReshapeOp::reifyReturnTypeShapesPerResultDim(
+LogicalResult TensorCollapseShapeOp::reifyReturnTypeShapesPerResultDim(
     OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
   auto resultShape =
       getAsValues(b, getLoc(),
@@ -2753,13 +2862,23 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
 // TODO: Consider making all this boilerplate easy to autogenerate
 // with Tablegen. This seems a desirable property in the context of
 // OpInterfaces where a Linalg "named" op **isa** LinalgOp.
-OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
-  return foldReshapeOp(*this, operands);
+  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+}
+OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+}
+OpFoldResult TensorExpandShapeOp::fold(ArrayRef<Attribute> operands) {
+  return foldReshapeOp<TensorExpandShapeOp, TensorCollapseShapeOp>(*this,
+                                                                   operands);
 }
-OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
-  return foldReshapeOp(*this, operands);
+OpFoldResult TensorCollapseShapeOp::fold(ArrayRef<Attribute> operands) {
+  return foldReshapeOp<TensorCollapseShapeOp, TensorExpandShapeOp>(*this,
+                                                                   operands);
 }
 
 //===----------------------------------------------------------------------===//
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 2627732790c77..1e627fe1ed82d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -149,17 +149,25 @@ class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
 
 /// Conversion pattern that replaces `linalg.tensor_reshape` with
 /// `linalg.reshape`.
+template <typename TensorReshapeOp,
+          typename Adaptor = typename TensorReshapeOp::Adaptor>
 class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
 public:
   using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
+  using ReshapeOp = typename std::conditional_t<
+      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value, ExpandShapeOp,
+      CollapseShapeOp>;
 
   LogicalResult
   matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    linalg::TensorReshapeOpAdaptor adaptor(operands, op->getAttrDictionary());
-    rewriter.replaceOpWithNewOp<linalg::ReshapeOp>(
-        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
-        adaptor.src(), adaptor.reassociation());
+    Adaptor adaptor(operands, op->getAttrDictionary());
+    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
+                                           this->getTypeConverter()
+                                               ->convertType(op.getType())
+                                               .template cast<MemRefType>(),
+                                           adaptor.src(),
+                                           adaptor.reassociation());
     return success();
   }
 };
@@ -348,7 +356,8 @@ void mlir::linalg::populateLinalgBufferizePatterns(
       BufferizeAnyLinalgOp,
       BufferizeFillOp,
       BufferizeInitTensorOp,
-      BufferizeTensorReshapeOp,
+      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
+      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
       SubTensorOpConverter,
       SubTensorInsertOpConverter
     >(typeConverter, patterns.getContext());
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 04b63535b6dee..c4056d9103842 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -31,7 +31,7 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
 
   // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
   // a tensor<dtype> instead.
-  return builder.create<linalg::TensorReshapeOp>(
+  return builder.create<linalg::TensorCollapseShapeOp>(
       loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
 }
 
@@ -159,8 +159,8 @@ class DetensorizeTypeConverter : public TypeConverter {
 /// Canonicalizes the pattern of the form
 ///
 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
-/// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
-///   tensor<i32>
+/// %reshaped_tensor = linalg.tensor_collapse_shape %tensor []
+///     : tensor<1xi32> into tensor<i32>
 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
 ///
 /// to just %element.
@@ -170,10 +170,11 @@ struct ExtractFromReshapeFromElements
 
   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                 PatternRewriter &rewriter) const final {
-    if (extract.indices().size() != 0)
+    if (!extract.indices().empty())
       return failure();
 
-    auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>();
+    auto tensorReshape =
+        extract.tensor().getDefiningOp<TensorCollapseShapeOp>();
     if (tensorReshape == nullptr)
       return failure();
 
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 54297f847b432..5e8820535a41d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -362,10 +362,11 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
       for (auto operand : llvm::enumerate(values)) {
         if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
           res.push_back(operand.value());
-        else
-          res.push_back(rewriter.create<linalg::TensorReshapeOp>(
+        else {
+          res.push_back(rewriter.create<TensorCollapseShapeOp>(
               loc, newInputOutputTypes[flattenedIdx], operand.value(),
               convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
+        }
         ++flattenedIdx;
       }
       return res;
@@ -395,11 +396,11 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
       RankedTensorType origResultType = genericOp.getResult(result.index())
                                             .getType()
                                             .template cast<RankedTensorType>();
-      if (origResultType != result.value().getType())
-        resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
+      if (origResultType != result.value().getType()) {
+        resultReplacements.push_back(rewriter.create<TensorExpandShapeOp>(
             loc, origResultType, result.value(),
             convertAffineMapArrayToExprs(reassociationMaps[index])));
-      else
+      } else
         resultReplacements.push_back(result.value());
     }
     rewriter.replaceOp(genericOp, resultReplacements);
@@ -460,8 +461,8 @@ struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
     Location loc = subTensorOp.getLoc();
     Value newSubTensor = rewriter.create<SubTensorOp>(
         loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
-    rewriter.replaceOpWithNewOp<TensorReshapeOp>(subTensorOp, resultType,
-                                                 newSubTensor, *reassociation);
+    rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(
+        subTensorOp, resultType, newSubTensor, *reassociation);
     return success();
   }
 };
@@ -482,7 +483,7 @@ struct UseRankReducedSubTensorInsertOp
         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
       return failure();
     Location loc = insertOp.getLoc();
-    auto reshapedSource = rewriter.create<TensorReshapeOp>(
+    auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
         loc, insertOp.source(), *reassociation);
     rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
         insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
@@ -500,7 +501,8 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
   patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
                UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
       context);
-  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
+  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
+  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
 }
 
 namespace {
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 9b2292f46c3a8..e202e627a18f7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -313,8 +313,7 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
 ///
 /// and reshape:
-/// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
-///                                affine_map<(i, j, k, l) -> (j, k, l)>] :
+/// %1 = linalg.tensor_collapse_shape %0 [[0], [0, 1, 2]] :
 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
 ///
 /// would be rewritten into:
@@ -355,24 +354,21 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                         resultExprs, context);
 }
 
-/// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is
-/// true) or its producer (if `asProducer` is false) given the indexing map at
-/// its use.
-static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp,
-                                                     AffineMap useIndexMap,
-                                                     bool asProducer) {
-  RankedTensorType returnType = reshapeOp.getResultType();
-  RankedTensorType operandType = reshapeOp.getSrcType();
-  // Reshape is fusable with its consumer (i.e. reshape as a producer) when its
-  // operand is of lesser rank than the result. Fusing when operand has higher
-  // rank will require use of mods and divs in the indexing maps of the fused op
-  // which would make it non-invertible. Similarly reshape is fused with its
-  // producer (i.e. reshape as consumer) only if the return type has lesser
-  // rank.
-  if ((asProducer && reshapeOp.getSrcType().hasStaticShape() &&
-       returnType.getRank() < operandType.getRank()) ||
-      (!asProducer && reshapeOp.getResultType().hasStaticShape() &&
-       operandType.getRank() < returnType.getRank()))
+// TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a
+// producer). Fusing when operand has higher rank will require use of mods and
+// divs in the indexing maps of the fused op which would make it non-invertible.
+static bool isTensorReshapeOpFoldableByLinearization(
+    TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
+  if (!asProducer && expandOp.getResultType().hasStaticShape())
+    return false;
+  return useIndexMap.isPermutation();
+}
+
+// TensorCollapseShapeOp is fusable with its producer (i.e. reshape as a
+// consumer).
+static bool isTensorReshapeOpFoldableByLinearization(
+    TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) {
+  if (asProducer && collapseOp.getSrcType().hasStaticShape())
     return false;
   return useIndexMap.isPermutation();
 }
@@ -405,17 +401,14 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
-///  %d = linalg.tensor_reshape %c
-///         [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
-///          affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
-///          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
+///  %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]]
 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 ///
 ///  The reshape can be folded into the `genericOp` if its loop dimensionality
-///  is increased to match the result (operand) of the tensor_reshape when the
-///  reshape is expanding (folding). The indexing_map of the fused tensor in the
-///  `genericOp` and the reassociation map helps compute the indexing maps of
-///  the modified op. For the above example, based on the reassociation map it
+///  is increased to match the result of the tensor_expand_shape op.
+///  The indexing_map of the fused tensor in the `genericOp` and the
+///  reassociation map helps compute the indexing maps of the modified op.
+///  For the above example, based on the reassociation map it
 ///  can be concluded that
 ///
 ///  - The loop used to access the first dimension of the fused tensor is split
@@ -443,14 +436,9 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
 ///  Since operands to the linalg generic are now 5D, reshapes can be introduced
 ///  to make it consistent
 ///
-///  %0 = linalg.tensor_reshape %a
-///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e2),
-///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e3, e4),
-///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e5)]
+///  %0 = linalg.tensor_expand_shape %a [[0, 1, 2], [3, 4], [5]]
 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
-///  %1 = linalg.tensor_reshape %b
-///         [affine_map<(e0, e1, e2, e3) -> (e0, e1, e2),
-///          affine_map<(e0, e1, e2, e3) -> (e3)]
+///  %1 = linalg.tensor_expand_shape %b [[0, 1, 2], [3]]
 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
 ///
 ///  The added reshapes are again expanding patterns, so they will get fused
@@ -614,11 +602,12 @@ static RankedTensorType getExpandedType(RankedTensorType originalType,
   return RankedTensorType::get(expandedShape, originalType.getElementType());
 }
 
-/// Returns the reassociation maps to use in the `linalg.tensor_reshape`
-/// operation to convert the operands of the origial operation to operands of
+/// Returns the reassociation maps to use in the `linalg.tensor_expand_shape`
+/// operation to convert the operands of the original operation to operands of
 /// the expanded operation. The same method is used to compute the
-/// `linalg.tensor_reshape` used to collapse the result of the expanded op to
-/// get the value that can replace all uses of the results of the original op.
+/// `linalg.tensor_collapse_shape` used to collapse the result of the expanded
+/// op to get the value that can replace all uses of the results of the original
+/// op.
 static SmallVector<ReassociationIndices>
 getReassociationForExpansion(AffineMap indexingMap,
                              const ExpansionInfo &expansionInfo) {
@@ -678,25 +667,29 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
   }
 }
 
-/// Implements the fusion of a tensor_reshape op and a generic op as explained
-/// in `isFusableWithReshapeByExpansion`. Assumes that those conditions have
-/// been satisfied.
+/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
+/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
+/// that those conditions have been satisfied.
 static Optional<SmallVector<Value>>
-fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
+fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                            OpOperand *fusableOpOperand,
                            PatternRewriter &rewriter) {
   assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
          "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
-  bool isExpanding =
-      reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
-  RankedTensorType expandedType =
-      isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
+  auto expandingReshapeOp = dyn_cast<TensorExpandShapeOp>(*reshapeOp);
+  auto collapsingReshapeOp = dyn_cast<TensorCollapseShapeOp>(*reshapeOp);
+  bool isExpanding = (expandingReshapeOp != nullptr);
+  RankedTensorType expandedType = isExpanding
+                                      ? expandingReshapeOp.getResultType()
+                                      : collapsingReshapeOp.getSrcType();
 
   ExpansionInfo expansionInfo;
-  if (failed(expansionInfo.compute(genericOp, fusableOpOperand,
-                                   reshapeOp.getReassociationMaps(),
-                                   expandedType.getShape(), rewriter)))
+  if (failed(expansionInfo.compute(
+          genericOp, fusableOpOperand,
+          isExpanding ? expandingReshapeOp.getReassociationMaps()
+                      : collapsingReshapeOp.getReassociationMaps(),
+          expandedType.getShape(), rewriter)))
     return llvm::None;
 
   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
@@ -710,7 +703,8 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
   SmallVector<Value> expandedOpOperands;
   for (OpOperand *opOperand : genericOp.getInputOperands()) {
     if (opOperand == fusableOpOperand) {
-      expandedOpOperands.push_back(reshapeOp.src());
+      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
+                                               : collapsingReshapeOp.src());
       continue;
     }
     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
@@ -721,7 +715,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
       // Reshape the operand to get the right type.
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
-      expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
+      expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
           genericOp.getLoc(), expandedOperandType, opOperand->get(),
           reassociation));
       continue;
@@ -739,7 +733,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
     if (expandedOutputType != opOperand->get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
-      outputs.push_back(rewriter.create<TensorReshapeOp>(
+      outputs.push_back(rewriter.create<TensorExpandShapeOp>(
           genericOp.getLoc(), expandedOutputType, opOperand->get(),
           reassociation));
     }
@@ -772,7 +766,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
               genericOp.getTiedIndexingMap(
                   genericOp.getOutputOperand(resultNumber)),
               expansionInfo);
-      resultVals.push_back(rewriter.create<TensorReshapeOp>(
+      resultVals.push_back(rewriter.create<TensorCollapseShapeOp>(
           genericOp.getLoc(), opResult.getType(),
           fusedOp->getResult(resultNumber), reassociation));
     } else {
@@ -785,18 +779,15 @@ fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
 
 namespace {
 
-/// Pattern to fold tensor_reshape op with its consumer by using the source of
-/// the reshape op as the operand in the consumer (instead of the result of the
-/// tensor_reshapeop) when the tensor_reshape op is collapsing. The
-/// corresponding index map in the consumer needs to be modified to linearize
-/// the folded dimension.
+/// Pattern to fold tensor_expand_shape op with its consumer by using the source
+/// of the reshape op as the operand in the consumer (instead of the result of
+/// the tensor_expand_shape op). The corresponding index map in the consumer
+/// needs to be modified to linearize the folded dimension.
 ///
 /// For example,
 ///
 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = linalg.tensor_reshape %arg0
-///        [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k)>,
-///         affine_map<(i, j, k, l) -> (l)>]
+/// %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]]
 ///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
 ///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
@@ -809,7 +800,7 @@ namespace {
 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
 ///        -> tensor<?x?x4x?xf32>
-template <bool foldUnitDimReshapesOnly>
+template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
 struct FoldProducerReshapeOpByLinearization
     : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
@@ -820,14 +811,19 @@ struct FoldProducerReshapeOpByLinearization
       return failure();
     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
     for (auto en : llvm::enumerate(inputOperands)) {
-      TensorReshapeOp reshapeOp =
-          en.value()->get().getDefiningOp<TensorReshapeOp>();
-      if (!reshapeOp ||
-          !isTensorReshapeOpFoldableByLinearization(
+      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
+      if (!reshapeOp)
+        continue;
+
+      Value src = reshapeOp.src();
+      RankedTensorType operandType = reshapeOp.getSrcType();
+      RankedTensorType returnType = reshapeOp.getResultType();
+
+      if (!isTensorReshapeOpFoldableByLinearization(
               reshapeOp, genericOp.getTiedIndexingMap(en.value()),
               /*asProducer =*/true) ||
           (foldUnitDimReshapesOnly &&
-           !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
+           !isUnitDimExpansionOnly(returnType.getShape(),
                                    reshapeOp.getReassociationMaps())))
         continue;
 
@@ -845,9 +841,8 @@ struct FoldProducerReshapeOpByLinearization
       auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
 
       // Compute the indexing map to use for the result of the producer.
-      AffineMap modifiedMap =
-          linearizeCollapsedDims(invMap, reshapeOp.getResultType().getShape(),
-                                 reshapeOp.getReassociationMaps());
+      AffineMap modifiedMap = linearizeCollapsedDims(
+          invMap, returnType.getShape(), reshapeOp.getReassociationMaps());
       for (AffineExpr expr : modifiedMap.getResults()) {
         if (!expr.isPureAffine())
           return failure();
@@ -893,8 +888,7 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
 ///
 /// For example,
 ///
-///  %0 = linalg.tensor_reshape %A [
-///    affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+///  %0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
 ///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
 ///  %2 = linalg.generic {indexing_maps = [
 ///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
@@ -912,8 +906,7 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
 ///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
 ///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
 ///  } -> tensor<12544x16xf32>
-///  %3 = linalg.tensor_reshape %2 [
-///    #affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+///  %3 = linalg.tensor_expand_shape %2 [[0, 1], [2]]
 ///    : tensor<12544x16xf32> into tensor<112x112x16xf32>
 struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
@@ -932,17 +925,15 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
       return failure();
     int64_t destRank = genericOp.getNumParallelLoops();
     SmallVector<Value> newOperands = genericOp.getInputOperands();
-    TensorReshapeOp reshapeFound;
-    // 1. Look for tensor_reshape operands and figure out save the dimensions
-    // merged.
+    TensorExpandShapeOp reshapeFound;
+    // 1. Look for tensor_expand_shape operands and record which dimensions
+    // are merged.
     SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
     for (auto en : llvm::enumerate(inputOperands)) {
-      TensorReshapeOp reshapeOp =
-          en.value()->get().template getDefiningOp<TensorReshapeOp>();
-      if (!reshapeOp || reshapeOp.getSrcType().getRank() >
-                            reshapeOp.getResultType().getRank()) {
+      auto reshapeOp =
+          en.value()->get().template getDefiningOp<TensorExpandShapeOp>();
+      if (!reshapeOp)
         continue;
-      }
       // TODO: We could support non-identity map as long as the merged
       // dimensions are still contiguous.
       if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
@@ -1007,7 +998,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
       auto newOutputType = RankedTensorType::get(
           reshapeFound.getSrcType().getShape(),
           output.getType().template cast<RankedTensorType>().getElementType());
-      Value newOutput = rewriter.create<TensorReshapeOp>(
+      Value newOutput = rewriter.create<TensorCollapseShapeOp>(
           genericOp->getLoc(), newOutputType, output, reassociation);
       newOutputTypes.push_back(newOutputType);
       newOutputs.push_back(newOutput);
@@ -1023,7 +1014,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
     // 6. Reshape the so that the type matches the uses.
     SmallVector<Value> newResults;
     for (auto result : llvm::enumerate(newOp->getResults())) {
-      newResults.push_back(rewriter.create<TensorReshapeOp>(
+      newResults.push_back(rewriter.create<TensorExpandShapeOp>(
           genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
           result.value(), reassociation));
     }
@@ -1032,9 +1023,9 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
   }
 };
 
-/// Pattern to fuse a tensor_reshape op with its consumer generic op, when the
-/// reshape op is collapsing dimensions. The dimensionality of the loop in the
-/// consumer is expanded.
+/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
+/// when the reshape op is collapsing dimensions. The dimensionality of the loop
+/// in the consumer is expanded.
 class FoldWithProducerReshapeOpByExpansion
     : public OpRewritePattern<GenericOp> {
 public:
@@ -1047,16 +1038,14 @@ class FoldWithProducerReshapeOpByExpansion
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     for (OpOperand *opOperand : genericOp.getInputOperands()) {
-      TensorReshapeOp reshapeOp =
-          opOperand->get().getDefiningOp<TensorReshapeOp>();
+      TensorCollapseShapeOp reshapeOp =
+          opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
       if (!reshapeOp)
         continue;
       // Fold only if
       // - The tensor reshape op is folding.
       // - All constraints of fusing with reshape by expansion are met.
-      if (reshapeOp.getSrcType().getRank() <
-              reshapeOp.getResultType().getRank() ||
-          !isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
+      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
           (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
         continue;
 
@@ -1074,16 +1063,17 @@ class FoldWithProducerReshapeOpByExpansion
   ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };
 
-/// Pattern to fold tensor_reshape op with its producer. The corresponding index
-/// map in the consumer needs to be modified to linearize the folded dimension.
-template <bool foldUnitDimReshapesOnly>
+/// Pattern to fold a tensor_collapse_shape or tensor_expand_shape op with its
+/// producer. The corresponding index map in the consumer needs to be modified
+/// to linearize the folded dimension.
+template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
 struct FoldConsumerReshapeOpByLinearization
     : public OpRewritePattern<TensorReshapeOp> {
   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
+    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
     if (!producer || !producer.hasTensorSemantics() ||
         producer.getNumOutputs() != 1 ||
         !isTensorReshapeOpFoldableByLinearization(
@@ -1141,19 +1131,14 @@ struct FoldConsumerReshapeOpByLinearization
   }
 };
 
-/// Pattern to fold a tensor_reshape op with its producer generic op if the
-/// tensor_reshape op is expanding, by expanding the dimensionality of the loop
-/// in the producer op.
+/// Pattern to fold a tensor_expand_shape op with its producer generic op
+/// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
-    : public OpRewritePattern<TensorReshapeOp> {
-  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+    : public OpRewritePattern<TensorExpandShapeOp> {
+  using OpRewritePattern<TensorExpandShapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    // Fold only if
-    // - The tensor reshape op is a expanding case.
-    // - All constraints of fusing with reshape by expansion are met.
-    if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank())
-      return failure();
+    // Fold only if all constraints of fusing with reshape by expansion are met.
     GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
     if (!producer || producer.getNumOutputs() != 1 ||
         !isFusableWithReshapeByDimExpansion(producer,
@@ -1260,9 +1245,14 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
 
 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                       const OpOperand &consumer) {
-  auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
-  return !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
-                                 reshapeOp.getReassociationMaps());
+  auto expandShapeOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
+  if (expandShapeOp)
+    return !isUnitDimExpansionOnly(expandShapeOp.getSrcType().getShape(),
+                                   expandShapeOp.getReassociationMaps());
+  auto collapseShapeOp =
+      producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
+  return !isUnitDimExpansionOnly(collapseShapeOp.getSrcType().getShape(),
+                                 collapseShapeOp.getReassociationMaps());
 }
 
 namespace {
@@ -1375,16 +1365,22 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
 
 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldProducerReshapeOpByLinearization<false>,
-               FoldConsumerReshapeOpByLinearization<false>>(
-      patterns.getContext());
+  patterns
+      .add<FoldProducerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
+           FoldProducerReshapeOpByLinearization<false, TensorExpandShapeOp>,
+           FoldConsumerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
+           FoldConsumerReshapeOpByLinearization<false, TensorExpandShapeOp>>(
+          patterns.getContext());
 }
 
 void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldProducerReshapeOpByLinearization<true>,
-               FoldConsumerReshapeOpByLinearization<true>>(
-      patterns.getContext());
+  patterns
+      .add<FoldProducerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
+           FoldProducerReshapeOpByLinearization<true, TensorExpandShapeOp>,
+           FoldConsumerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
+           FoldConsumerReshapeOpByLinearization<true, TensorExpandShapeOp>>(
+          patterns.getContext());
 }
 
 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
@@ -1406,7 +1402,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
-  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
+  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
 }
 
 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1d2996c95fa63..15d0bf5129dc4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -61,7 +61,7 @@ func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
 // CHECK-LABEL: @test_broadcast
 func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %arg1 : tensor<f32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
   // CHECK:   [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
@@ -79,7 +79,7 @@ func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32
 // CHECK-LABEL: @test_broadcast_swapped_args
 func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2] : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg1
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg1
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
   // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
   // CHECK:   [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
@@ -98,8 +98,8 @@ func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) ->
 // CHECK-LABEL: @test_multibroadcast
 func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3] : tensor<2x3xf32>
-  // CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 {{\[}}[0, 1]]
-  // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape %arg1 {{\[}}[0, 1]]
+  // CHECK: [[RESHAPE1:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]]
+  // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_collapse_shape %arg1 {{\[}}[0, 1]]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
   // CHECK: ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
   // CHECK:   [[ELEMENT:%.+]] = addf %arg2, %arg3 : f32
@@ -467,7 +467,7 @@ func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
 
 // CHECK-LABEL: @test_reshape_downrank
 func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
-  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 {{\[}}[0, 1]]
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]]
   %0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32>
   // CHECK: return [[RESHAPE]]
   return %0 : tensor<6xf32>
@@ -477,7 +477,7 @@ func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
 
 // CHECK-LABEL: @test_reshape_uprank
 func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
-  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 {{\[}}[0, 1]]
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
   %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32>
   // CHECK: return [[RESHAPE]]
   return %0 : tensor<2x3xf32>
@@ -488,8 +488,8 @@ func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
 // CHECK-LABEL: @test_reshape_samerank
 func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
   // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
-  // CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_reshape %[[RESHAPE1]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
   %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32>
   // CHECK-NEXT: return %[[RESHAPE2]]
   return %0 : tensor<2x3xf32>
@@ -499,7 +499,7 @@ func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
 
 // CHECK-LABEL: @test_reshape_downrank_6D
 func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
-  // CHECK: linalg.tensor_reshape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
+  // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
   %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
   return %0 : tensor<6x5x77xf32>
 }
@@ -549,7 +549,7 @@ func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   [[RES:%.+]] = addf %arg1, %arg2 : f32
   // CHECK:   linalg.yield [[RES]] : f32
-  // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
+  // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
 
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
@@ -559,7 +559,7 @@ func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK:   [[RES:%.+]] = addf %arg1, %arg2 : f32
   // CHECK:   linalg.yield [[RES]] : f32
-  // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
+  // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
   %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>
 
   // CHECK: constant 1.0
@@ -600,7 +600,7 @@ func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK: ^bb0(%arg1: i32, %arg2: i32)
   // CHECK:   [[RES:%.+]] = addi %arg1, %arg2 : i32
   // CHECK:   linalg.yield [[RES]] : i32
-  // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
+  // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
   %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
 
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
@@ -610,7 +610,7 @@ func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK: ^bb0(%arg1: i32, %arg2: i32)
   // CHECK:   [[RES:%.+]] = addi %arg1, %arg2 : i32
   // CHECK:   linalg.yield [[RES]] : i32
-  // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
+  // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
   %1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32>
 
   // CHECK: constant 1
@@ -650,7 +650,7 @@ func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
   // CHECK: ^bb0(%arg1: i1, %arg2: i1)
   // CHECK:   [[RES:%.+]] = and %arg1, %arg2 : i1
   // CHECK:   linalg.yield [[RES]] : i1
-  // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
+  // CHECK: linalg.tensor_expand_shape [[GENERIC]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
   %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
 
   // CHECK: constant false
@@ -822,19 +822,19 @@ func @tile(%arg0 : tensor<2x3xi8>) -> () {
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
   // CHECK:   linalg.yield %arg1 : i8
-  // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1, 2], [3]]
+  // CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1, 2], [3]]
   %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>)  -> (tensor<4x3xi8>)
 
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
   // CHECK:   linalg.yield %arg1 : i8
-  // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
+  // CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
   %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>)  -> (tensor<2x6xi8>)
 
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
   // CHECK:   linalg.yield %arg1 : i8
-  // CHECK: linalg.tensor_reshape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
+  // CHECK: linalg.tensor_collapse_shape [[GENERIC]] {{\[}}[0, 1], [2, 3]]
   %2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>)  -> (tensor<10x21xi8>)
 
   return
@@ -1097,7 +1097,7 @@ func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
 func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
   // Initial piece computes the sum of the pooling region, with appropriate padding.
   // CHECK: [[CONST:%.+]] = constant 0
-  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] 
+  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK: [[CONST:%.+]] = constant 0
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
   // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
@@ -1188,9 +1188,9 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
   // CHECK: ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
   // CHECK:   linalg.yield %arg3 : f32
   // CHECK: } -> tensor<1x5x5x33xf32>
-  // CHECK: [[DBIAS:%.+]] = linalg.tensor_reshape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
+  // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
-  // CHECK: linalg.tensor_reshape %3 {{\[}}[0], [1], [2], [3, 4]]
+  // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]]
   %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>)  -> (tensor<1x5x5x33xf32>)
   return
 }
@@ -1202,10 +1202,10 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
 func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
   // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK: %[[IDX0:.+]] = linalg.index 0 
-  // CHECK: %[[IDX1:.+]] = linalg.index 1 
-  // CHECK: %[[IDX2:.+]] = linalg.index 2 
-  // CHECK: %[[IDX3:.+]] = linalg.index 3 
+  // CHECK: %[[IDX0:.+]] = linalg.index 0
+  // CHECK: %[[IDX1:.+]] = linalg.index 1
+  // CHECK: %[[IDX2:.+]] = linalg.index 2
+  // CHECK: %[[IDX3:.+]] = linalg.index 3
   // CHECK-DAG: %[[XYMIN:.+]] = constant 0
   // CHECK-DAG: %[[YMAX:.+]] = constant 1
   // CHECK-DAG: %[[XMAX:.+]] = constant 1
@@ -1271,9 +1271,9 @@ func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
 func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
   // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK: %[[IDX0:.+]] = linalg.index 0 
-  // CHECK: %[[IDX1:.+]] = linalg.index 1 
-  // CHECK: %[[IDX2:.+]] = linalg.index 2 
+  // CHECK: %[[IDX0:.+]] = linalg.index 0
+  // CHECK: %[[IDX1:.+]] = linalg.index 1
+  // CHECK: %[[IDX2:.+]] = linalg.index 2
   // CHECK: %[[IDX3:.+]] = linalg.index 3
   // CHECK: %[[XYMIN:.+]] = constant 0
   // CHECK: %[[YMAX:.+]] = constant 1
@@ -1353,9 +1353,9 @@ func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () {
 func @resize_nearest_int(%input: tensor<1x2x2x1xi32>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
   // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK: %[[IDX0:.+]] = linalg.index 0 
-  // CHECK: %[[IDX1:.+]] = linalg.index 1 
-  // CHECK: %[[IDX2:.+]] = linalg.index 2 
+  // CHECK: %[[IDX0:.+]] = linalg.index 0
+  // CHECK: %[[IDX1:.+]] = linalg.index 1
+  // CHECK: %[[IDX2:.+]] = linalg.index 2
   // CHECK: %[[IDX3:.+]] = linalg.index 3
   // CHECK-DAG: %[[XYMIN:.+]] = constant 0
   // CHECK-DAG: %[[YMAX:.+]] = constant 1
@@ -1422,7 +1422,7 @@ func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
   // CHECK: %[[GENERIC:.+]] = linalg.generic
 
-  // CHECK: %[[IDX0:.+]] = linalg.index 0 
+  // CHECK: %[[IDX0:.+]] = linalg.index 0
   // CHECK: %[[IDX3:.+]] = linalg.index 3
 
   // CHECK: %[[XYMIN:.+]] = constant 0
diff  --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 5733f5963ad11..9f3406e27d084 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -253,15 +253,15 @@ func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> {
 
 // -----
 
-// CHECK-LABEL: func @bufferize_tensor_reshape(
+// CHECK-LABEL: func @bufferize_tensor_collapse_shape(
 // CHECK-SAME:    %[[IN:.*]]: tensor<4x5xf32>
-func @bufferize_tensor_reshape(%arg0: tensor<4x5xf32>) -> tensor<20xf32> {
-  %out = linalg.tensor_reshape %arg0 [[0, 1]] :
+func @bufferize_tensor_collapse_shape(%arg0: tensor<4x5xf32>) -> tensor<20xf32> {
+  %out = linalg.tensor_collapse_shape %arg0 [[0, 1]] :
      tensor<4x5xf32> into tensor<20xf32>
   return %out : tensor<20xf32>
 }
 // CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x5xf32>
-// CHECK: %[[RESHAPE:.*]] = linalg.reshape %[[MEMREF]] {{\[}}[0, 1]]
+// CHECK: %[[RESHAPE:.*]] = linalg.collapse_shape %[[MEMREF]] {{\[}}[0, 1]]
 // CHECK-SAME: : memref<4x5xf32> into memref<20xf32>
 // CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32>
 // CHECK: return %[[TENSOR]]
diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index bcdbca57324aa..b3915477ce476 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -46,9 +46,9 @@ func @memref_cast_into_tiled_loop(%arg0: memref<192xf32>)  {
 // CHECK-LABEL: zero_rank_reshape_multi
 func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
-  %0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32>
-  %2 = linalg.tensor_reshape %1 [] : tensor<1x1xf32> into tensor<f32>
+  %0 = linalg.tensor_expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
+  %1 = linalg.tensor_expand_shape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32>
+  %2 = linalg.tensor_collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
   return %2 : tensor<f32>
 }
 
@@ -56,175 +56,175 @@ func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
 
 func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4]]
       : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1], [2]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
       : tensor<?x?x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: collapsing_tensor_reshapes
-//       CHECK:   linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
-//   CHECK-NOT:   linalg.tensor_reshape
+//       CHECK:   linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+//   CHECK-NOT:   linalg.tensor_collapse_shape
 
 // -----
 
 func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
                                              -> tensor<f32> {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2]]
       : tensor<1x1x1xf32> into tensor<1xf32>
-  %1 = linalg.tensor_reshape %0 [] : tensor<1xf32> into tensor<f32>
+  %1 = linalg.tensor_collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
   return %1 : tensor<f32>
 }
 // CHECK-LABEL: collapsing_tensor_reshapes_to_zero
-//       CHECK:   linalg.tensor_reshape %{{.*}} []
+//       CHECK:   linalg.tensor_collapse_shape %{{.*}} []
 //  CHECK-SAME:     tensor<1x1x1xf32> into tensor<f32>
 
 // -----
 
 func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
                                              -> memref<f32> {
-  %0 = linalg.reshape %arg0 [[0, 1, 2]]
+  %0 = linalg.collapse_shape %arg0 [[0, 1, 2]]
       : memref<1x1x1xf32> into memref<1xf32>
-  %1 = linalg.reshape %0 [] : memref<1xf32> into memref<f32>
+  %1 = linalg.collapse_shape %0 [] : memref<1xf32> into memref<f32>
   return %1 : memref<f32>
 }
 // CHECK-LABEL: collapsing_memref_reshapes_to_zero
-//       CHECK:   linalg.reshape %{{.*}} []
+//       CHECK:   linalg.collapse_shape %{{.*}} []
 //  CHECK-SAME:     memref<1x1x1xf32> into memref<f32>
 
 // -----
 
 func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x6x4x?x5xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1], [2], [3, 4]]
+  %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4]]
       : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
   return %1 : tensor<?x6x4x?x5xf32>
 }
 // CHECK-LABEL: expanding_tensor_reshapes
-//       CHECK:   linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
-//   CHECK-NOT:   linalg.tensor_reshape
+//       CHECK:   linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+//   CHECK-NOT:   linalg.tensor_expand_shape
 
 // -----
 
 func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>) -> memref<?x?xf32>
 {
-  %0 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]]
+  %0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
       : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
-  %1 = linalg.reshape %0 [[0, 1], [2]]
+  %1 = linalg.collapse_shape %0 [[0, 1], [2]]
       : memref<?x?x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
 // CHECK-LABEL: collapsing_memref_reshapes
-//       CHECK:   linalg.reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
-//   CHECK-NOT:   linalg.reshape
+//       CHECK:   linalg.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+//   CHECK-NOT:   linalg.collapse_shape
 
 // -----
 
 func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x6x4x5x?xf32>
 {
-  %0 = linalg.reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x?xf32>
-  %1 = linalg.reshape %0 [[0, 1], [2], [3, 4]]
+  %1 = linalg.expand_shape %0 [[0, 1], [2], [3, 4]]
       : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
   return %1 : memref<?x6x4x5x?xf32>
 }
 // CHECK-LABEL: expanding_memref_reshapes
-//       CHECK:   linalg.reshape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
-//   CHECK-NOT:   linalg.reshape
+//       CHECK:   linalg.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+//   CHECK-NOT:   linalg.expand_shape
 
 // -----
 
 func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
                                              -> tensor<1x1x1xf32> {
-  %0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1, 2]] 
+  %0 = linalg.tensor_expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
+  %1 = linalg.tensor_expand_shape %0 [[0, 1, 2]]
       : tensor<1xf32> into tensor<1x1x1xf32>
   return %1 : tensor<1x1x1xf32>
 }
 // CHECK-LABEL: expanding_tensor_reshapes_to_zero
-//       CHECK:   linalg.tensor_reshape %{{.*}} []
+//       CHECK:   linalg.tensor_expand_shape %{{.*}} []
 //  CHECK-SAME:     tensor<f32> into tensor<1x1x1xf32>
 
 // -----
 
 func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
                                              -> memref<1x1x1xf32> {
-  %0 = linalg.reshape %arg0 [] : memref<f32> into memref<1xf32>
-  %1 = linalg.reshape %0 [[0, 1, 2]]
+  %0 = linalg.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
+  %1 = linalg.expand_shape %0 [[0, 1, 2]]
       : memref<1xf32> into memref<1x1x1xf32>
   return %1 : memref<1x1x1xf32>
 }
 // CHECK-LABEL: expanding_memref_reshapes_to_zero
-//       CHECK:   linalg.reshape %{{.*}} []
+//       CHECK:   linalg.expand_shape %{{.*}} []
 //  CHECK-SAME:     memref<f32> into memref<1x1x1xf32>
 
 // -----
 
 func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
       : tensor<12x4xf32> into tensor<3x4x4xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1], [2]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
       : tensor<3x4x4xf32> into tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
 // CHECK-LABEL: @fold_tensor_reshape
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.{{.*}}shape
 
 // -----
 
 func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1], [2]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: @fold_tensor_reshape_dynamic
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.{{.*}}_shape
 
 // -----
 
 func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>
 {
-  %0 = linalg.reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
       : memref<12x4xf32> into memref<3x4x4xf32>
-  %1 = linalg.reshape %0 [[0, 1], [2]]
+  %1 = linalg.collapse_shape %0 [[0, 1], [2]]
       : memref<3x4x4xf32> into memref<12x4xf32>
   return %1 : memref<12x4xf32>
 }
 // CHECK-LABEL: @fold_memref_reshape
-//   CHECK-NOT:   linalg.reshape
+//   CHECK-NOT:   linalg.{{.*}}_shape
 
 // -----
 
 func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
 {
-  %0 = linalg.reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x?xf32>
-  %1 = linalg.reshape %0 [[0, 1], [2]]
+  %1 = linalg.collapse_shape %0 [[0, 1], [2]]
       : memref<?x4x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
 // CHECK-LABEL: @fold_memref_reshape_dynamic
-//   CHECK-NOT:   linalg.reshape
+//   CHECK-NOT:   linalg.{{.*}}_shape
 
 // -----
 
 func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
       : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1, 2, 3]]
+  %1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3]]
       : tensor<40320xf32> into tensor<24x5x42x8xf32>
   return %1 : tensor<24x5x42x8xf32>
 }
 //      CHECK: func @reshape_collapse
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
-//      CHECK:   %[[RESULT:.+]] = linalg.tensor_reshape %[[ARG0]]
+//      CHECK:   %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
 // CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
 //      CHECK:   return %[[RESULT]]
 
@@ -232,15 +232,15 @@ func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf3
 
 func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>) -> tensor<2x3x4x5x6x7x8xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3]]
       : tensor<24x5x42x8xf32> into tensor<40320xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1, 2, 3, 4, 5, 6]]
+  %1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]]
       : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
   return %1 : tensor<2x3x4x5x6x7x8xf32>
 }
 //      CHECK: func @reshape_expand
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<24x5x42x8xf32>
-//      CHECK:   %[[RESULT:.+]] = linalg.tensor_reshape %[[ARG0]]
+//      CHECK:   %[[RESULT:.+]] = linalg.tensor_expand_shape %[[ARG0]]
 // CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
 //      CHECK:   return %[[RESULT]]
 
@@ -248,84 +248,84 @@ func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>) -> tensor<2x3x4x5x6x7x8xf32>
 
 func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3]]
     : tensor<2048xf32> into tensor<1x4x1x512xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1, 2], [3]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]]
     : tensor<1x4x1x512xf32> into tensor<4x512xf32>
   return %1 : tensor<4x512xf32>
 }
 //       CHECK: func @expand_reshape_1D
-//       CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]]
+//       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<2048xf32> into tensor<4x512xf32>
 
 // -----
 
 func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2], [3]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2], [3]]
     : tensor<4x512xf32> into tensor<1x4x1x512xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1, 2, 3]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]]
     : tensor<1x4x1x512xf32> into tensor<2048xf32>
   return %1 : tensor<2048xf32>
 }
 //       CHECK: func @fold_reshape_1D
-//       CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]]
+//       CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<4x512xf32> into tensor<2048xf32>
 
 // -----
 
 func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>) -> tensor<4x512x1x1xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3], [4], [5]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
     : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1, 2], [3], [4], [5]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3], [4], [5]]
     : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
   return %1 : tensor<4x512x1x1xf32>
 }
 //       CHECK: func @fold_reshape_unit_dims
-//       CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2], [3]]
+//       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
 //  CHECK-SAME:   tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
 
 // -----
 
 func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
     : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
     : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
 //       CHECK: func @expand_reshape_unit_dims
-//       CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
+//       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
 //  CHECK-SAME:   tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
 
 // -----
 
 func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]]
       : tensor<2xf32> into tensor<2x1x1xf32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2]]
+  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]]
       : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
 //       CHECK: func @fold_reshape_trailing_unit_dims
-//       CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]]
+//       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
 
 // -----
 
 func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
     : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
-  %1 = linalg.tensor_reshape %0 [[0], [1], [2, 3, 4], [5]]
+  %1 = linalg.tensor_collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
     : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
 //       CHECK: func @collapse_reshape_unit_dims_dynamic
-//       CHECK: linalg.tensor_reshape
+//       CHECK: linalg.tensor_collapse_shape
 //  CHECK-SAME:   [0], [1, 2], [3, 4, 5], [6, 7, 8]
 //  CHECK-SAME:   tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
 
@@ -333,72 +333,72 @@ func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>)
 
 func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]]
       : tensor<2xf32> into tensor<2x1x1xf32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2]]
+  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]]
       : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
 //       CHECK: func @fold_reshape_trailing_unit_dims
-//       CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]]
+//       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
 //  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
 
 // -----
 
 func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2], [3], [4], [5]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
       : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1, 2, 3]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]]
       : tensor<?x1x1x1xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
 //       CHECK: func @fold_reshape_trailing_unit_dims_dynamic
-//       CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
+//       CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
 //  CHECK-SAME:   tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
 
 // -----
 
 func @no_fold_reshapes(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]]
+  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]]
       : tensor<?x?x1x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: func @no_fold_reshapes
-//       CHECK:   linalg.tensor_reshape
-//       CHECK:   linalg.tensor_reshape
+//       CHECK:   linalg.tensor_expand_shape
+//       CHECK:   linalg.tensor_collapse_shape
 
 // -----
 
 func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>) -> tensor<2x6x16xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2, 3], [4]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2, 3], [4]]
       : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2], [3, 4]]
+  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2], [3, 4]]
       : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
   return %1 : tensor<2x6x16xf32>
 }
 // CHECK-LABEL: func @no_fold_reshape_incompatible
-//       CHECK:   linalg.tensor_reshape
-//       CHECK:   linalg.tensor_reshape
+//       CHECK:   linalg.tensor_expand_shape
+//       CHECK:   linalg.tensor_collapse_shape
 
 // -----
 
 func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1, 2], [3]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]]
       : tensor<3x2x2x1xf32> into tensor<12x1xf32>
   return %1 : tensor<12x1xf32>
 }
 //      CHECK: func @no_fold_reshape_empty_expr
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<3x2x2xf32>
-//      CHECK:    %[[RARG0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//      CHECK:    %[[RARG0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
 // CHECK-SAME:      [0], [1], [2, 3]
-//      CHECK:    %[[RES:.+]] = linalg.tensor_reshape %[[RARG0]]
+//      CHECK:    %[[RES:.+]] = linalg.tensor_collapse_shape %[[RARG0]]
 // CHECK-SAME:      [0, 1, 2], [3]
 //      CHECK:    return %[[RES:.+]] : tensor<12x1xf32>
 
@@ -436,49 +436,49 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
 func @reshape_splat_constant_int32() -> tensor<2x4x2xi32>
 {
   %c0 = constant dense<42> : tensor<2x8xi32>
-  %0 = linalg.tensor_reshape %c0 [[0], [1, 2]]
+  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
       : tensor<2x8xi32> into tensor<2x4x2xi32>
   return %0 : tensor<2x4x2xi32>
 }
 // CHECK-LABEL: @reshape_splat_constant_int32
 //       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi32>
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.tensor_expand_shape
 //       CHECK:   return %[[CST]]
 
 func @reshape_splat_constant_int16() -> tensor<2x4x2xi16>
 {
   %c0 = constant dense<42> : tensor<2x8xi16>
-  %0 = linalg.tensor_reshape %c0 [[0], [1, 2]]
+  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
       : tensor<2x8xi16> into tensor<2x4x2xi16>
   return %0 : tensor<2x4x2xi16>
 }
 // CHECK-LABEL: @reshape_splat_constant_int16
 //       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi16>
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.tensor_expand_shape
 //       CHECK:   return %[[CST]]
 
 func @reshape_splat_constant_float32() -> tensor<2x4x2xf32>
 {
   %c0 = constant dense<42.0> : tensor<2x8xf32>
-  %0 = linalg.tensor_reshape %c0 [[0], [1, 2]]
+  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
       : tensor<2x8xf32> into tensor<2x4x2xf32>
   return %0 : tensor<2x4x2xf32>
 }
 // CHECK-LABEL: @reshape_splat_constant_float32
 //       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf32>
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.tensor_expand_shape
 //       CHECK:   return %[[CST]]
 
 func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
 {
   %c0 = constant dense<42.0> : tensor<2x8xf64>
-  %0 = linalg.tensor_reshape %c0 [[0], [1, 2]]
+  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
       : tensor<2x8xf64> into tensor<2x4x2xf64>
   return %0 : tensor<2x4x2xf64>
 }
 // CHECK-LABEL: @reshape_splat_constant_float64
 //       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.tensor_expand_shape
 //       CHECK:   return %[[CST]]
 
 // -----
@@ -733,7 +733,7 @@ func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
 
 func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
   %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1], [2], [3, 4, 5]]
+  %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]]
       : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
@@ -748,7 +748,7 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
 
 func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
   %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32>
-  %1 = linalg.tensor_reshape %0 [[0, 1], [2], [3, 4, 5]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
       : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
   return %1 : tensor<6x5x?xf32>
 }
@@ -898,7 +898,7 @@ func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
   %c1 = constant 1 : index
   %c3 = constant 3 : index
   %c4 = constant 4 : index
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4, 5]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
       : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
   %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
   %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
@@ -921,7 +921,7 @@ func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
 {
   %c1 = constant 1 : index
   %c2 = constant 2 : index
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4, 5]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
       : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
   %1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
   %2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
@@ -979,7 +979,7 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
   %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32>
   // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<6x4xf32>, f32 -> tensor<6x4xf32>
   %fill = linalg.fill(%init, %zero) : tensor<1x2x3x4xf32>, f32 -> tensor<1x2x3x4xf32>
-  %reshape = linalg.tensor_reshape %fill [[0, 1, 2], [3]]
+  %reshape = linalg.tensor_collapse_shape %fill [[0, 1, 2], [3]]
       : tensor<1x2x3x4xf32> into tensor<6x4xf32>
   // CHECK: return %[[FILL]] : tensor<6x4xf32>
   return %reshape : tensor<6x4xf32>
@@ -991,10 +991,10 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?x?xf32>
 func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> {
   %zero = constant 0.0 : f32
-  // CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
+  // CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
   %0 = linalg.fill(%arg0, %zero) : tensor<?x?x?x?x?xf32>, f32 -> tensor<?x?x?x?x?xf32>
   // CHECK: %[[RESULT:.+]] = linalg.fill(%[[RESHAPE]], %{{.+}})
-  %1 = linalg.tensor_reshape %0 [[0, 1, 2], [3, 4]]
+  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4]]
       : tensor<?x?x?x?x?xf32> into tensor<?x?xf32>
   // CHECK: return %[[RESULT]]
   return %1 : tensor<?x?xf32>
diff  --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
index 91e3f080bedcb..be568901a868b 100644
--- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
@@ -19,7 +19,7 @@ func @detensor_simple(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> att
 // CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
 // CHECK:         %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
 // CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
-// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]]
 // CHECK:         return %[[reshaped_tensor_res]]
 
 func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
@@ -61,7 +61,7 @@ func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32
 // CHECK:         %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]]
 // CHECK:         %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]]
 // CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
-// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]]
 // CHECK:         return %[[reshaped_tensor_res]]
 
 func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
@@ -83,7 +83,7 @@ func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f3
 // CHECK:         %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
 // CHECK:         %[[detensored_res2:.*]] = mulf %[[detensored_res]], %[[arg2_val]]
 // CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]]
-// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]]
 // CHECK:         return %[[reshaped_tensor_res]]
 
 func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
@@ -103,5 +103,5 @@ func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32>
 // CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
 // CHECK:         %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]])
 // CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
-// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]]
 // CHECK:         return %[[reshaped_tensor_res]]
diff  --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index c5f7eaebbba43..6e2f315242781 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -10,10 +10,10 @@
 func @main() -> (tensor<i32>) attributes {} {
   %c0 = constant 0 : i32
   %0 = tensor.from_elements %c0 : tensor<1xi32>
-  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
   %c10 = constant 10 : i32
   %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
   br ^bb1(%reshaped0 : tensor<i32>)
 
 ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
@@ -55,7 +55,7 @@ func @main() -> (tensor<i32>) attributes {} {
 // CHECK-NEXT:     br ^[[bb3:.*]](%{{.*}} : i32)
 // CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
 // CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<1xi32>
-// CHECK-NEXT:     linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT:     linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
 // CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
 
@@ -74,10 +74,10 @@ func @main() -> (tensor<i32>) attributes {} {
 func @main() -> (tensor<i32>) attributes {} {
   %c0 = constant 0 : i32
   %0 = tensor.from_elements %c0 : tensor<1xi32>
-  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
   %c10 = constant 10 : i32
   %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
   br ^bb1(%reshaped0 : tensor<i32>)
 
 ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
@@ -124,7 +124,7 @@ func @main() -> (tensor<i32>) attributes {} {
 // CHECK-NEXT:     br ^[[bb4:.*]](%{{.*}} : i32)
 // CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
 // CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<1xi32>
-// CHECK-NEXT:     linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT:     linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
 // CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
 
@@ -140,10 +140,10 @@ func @main() -> (tensor<i32>) attributes {} {
 func @main() -> (tensor<i32>) attributes {} {
   %c0 = constant 0 : i32
   %0 = tensor.from_elements %c0 : tensor<1xi32>
-  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
   %c10 = constant 10 : i32
   %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
   br ^bb1(%reshaped0 : tensor<i32>)
 
 ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
@@ -164,7 +164,7 @@ func @main() -> (tensor<i32>) attributes {} {
 
 ^bb2(%6: tensor<i32>):  // pred: ^bb1
   %12 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped12 = linalg.tensor_reshape %12 [] : tensor<1xi32> into tensor<i32>
+  %reshaped12 = linalg.tensor_collapse_shape %12 [] : tensor<1xi32> into tensor<i32>
   %7 = linalg.init_tensor [] : tensor<i32>
   %8 = linalg.generic #attrs
     ins(%6, %reshaped12 : tensor<i32>, tensor<i32>)
@@ -191,6 +191,6 @@ func @main() -> (tensor<i32>) attributes {} {
 // CHECK-NEXT:     br ^[[bb3:.*]](%{{.*}} : i32)
 // CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
 // CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<1xi32>
-// CHECK-NEXT:     linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT:     linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
 // CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
diff  --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
index 4e0b8fdd00468..ff6883040d8d7 100644
--- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
@@ -12,7 +12,7 @@
 func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
   %c10 = constant 10 : i32
   %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
   %3 = linalg.init_tensor [] : tensor<i1>
   %4 = linalg.generic #attrs
     ins(%farg0, %reshaped1 : tensor<i32>, tensor<i32>)
@@ -30,7 +30,7 @@ func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
 // DET-ALL-NEXT:    tensor.extract %{{.*}}[]
 // DET-ALL-NEXT:    cmpi slt, %{{.*}}, %{{.*}}
 // DET-ALL-NEXT:    tensor.from_elements %{{.*}}
-// DET-ALL-NEXT:    linalg.tensor_reshape %{{.*}}
+// DET-ALL-NEXT:    linalg.tensor_collapse_shape %{{.*}}
 // DET-ALL-NEXT:    return %{{.*}} : tensor<i1>
 // DET-ALL-NEXT:  }
 
diff  --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index 4b4cf3aa98bfb..4f1c20bead283 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -52,7 +52,7 @@ func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {
 // DET-ALL:         br ^[[bb1]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb3]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements {{.*}}
-// DET-ALL:         linalg.tensor_reshape {{.*}}
+// DET-ALL:         linalg.tensor_collapse_shape {{.*}}
 // DET-ALL:         return %{{.*}} : tensor<i32>
 
 // Test detensoring only ops involved in control-flow.
@@ -69,5 +69,5 @@ func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {
 // DET-CF:         br ^[[bb1]](%{{.*}} : i32)
 // DET-CF:       ^[[bb3]](%{{.*}}: i32)
 // DET-CF:         tensor.from_elements %{{.*}} : tensor<1xi32>
-// DET-CF:         linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-CF:         linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
 // DET-CF:         return %{{.*}} : tensor<i32>
diff  --git a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
index 36361d5492110..b7db6adfbfe49 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
@@ -80,7 +80,7 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
 // DET-ALL:         cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb2]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements %{{.*}} : tensor<1xi32>
-// DET-ALL:         linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL:         linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
 // DET-ALL:         linalg.init_tensor [10] : tensor<10xi32>
 // DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
 // DET-ALL:         ^bb0(%{{.*}}: i32, %{{.*}}: i32):
@@ -89,11 +89,11 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
 // DET-ALL:         br ^[[bb1]](%{{.*}} : tensor<10xi32>)
 // DET-ALL:       ^[[bb3]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements %{{.*}} : tensor<1xi32>
-// DET-ALL:         linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL:         linalg.tensor_collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
 // DET-ALL:         return %{{.*}} : tensor<i32>
 // DET-ALL:       }
 
-// Try to detensor pure control-flow. However, that fails since the potential 
+// Try to detensor pure control-flow. However, that fails since the potential
 // detensorable component contains some ops that cannot be detensored.
 //
 // DET-CF-LABEL: func @main
diff  --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
index b0d88efadca70..f230a91e52b7a 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -10,10 +10,10 @@
 func @main() -> () attributes {} {
   %c0 = constant 0 : i32
   %0 = tensor.from_elements %c0 : tensor<1xi32>
-  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %reshaped0 = linalg.tensor_collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
   %c10 = constant 10 : i32
   %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  %reshaped1 = linalg.tensor_collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
   br ^bb1(%reshaped0 : tensor<i32>)
 
 ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5bc11f20c7cae..bd4f66defe6ee 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -23,11 +23,11 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf3
 //   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 //   CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @drop_one_trip_loops
-//       CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2]]
+//       CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
 //       CHECK: linalg.generic
 //  CHECK-SAME:   indexing_maps = [#[[$MAP2]], #[[$MAP3]]]
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
-//       CHECK: linalg.tensor_reshape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+//       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
 
 // -----
 
@@ -101,7 +101,7 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
 }
 //       CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
 // CHECK-LABEL: func @drop_all_loops
-//       CHECK:   linalg.tensor_reshape %{{.*}} []
+//       CHECK:   linalg.tensor_collapse_shape %{{.*}} []
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
 //  CHECK-SAME:     iterator_types = []
@@ -162,7 +162,7 @@ func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf3
 //   CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: func @leading_dim_1_canonicalization
-//       CHECK:   linalg.tensor_reshape %{{.*}} {{\[}}[0, 1]]
+//       CHECK:   linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1]]
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[$MAP1]], #[[$MAP1]]]
 //  CHECK-SAME:     iterator_types = ["parallel"]
@@ -183,8 +183,8 @@ func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf3
 
 func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>, %shape : tensor<5x5xf32>) -> tensor<5x5xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32>
-  %1 = linalg.tensor_reshape %arg1 [[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32>
+  %1 = linalg.tensor_expand_shape %arg1 [[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
   %2 = linalg.generic #trait
      ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>)
     outs(%shape : tensor<5x5xf32>) {
@@ -198,11 +198,11 @@ func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>, %shape : tens
 //   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 //   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @broadcast_test
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.tensor_{{.*}}shape
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME:     iterator_types = ["parallel", "parallel"]
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.tensor_{{.*}}shape
 
 // -----
 
@@ -231,7 +231,7 @@ func @broadcast_scalar(%arg0 : tensor<1x1xf32>, %shape : tensor<?x?xf32>) -> ten
 //   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @broadcast_scalar
 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<1x1xf32>
-//       CHECK:   %[[A:.*]] = linalg.tensor_reshape %[[ARG0]] []
+//       CHECK:   %[[A:.*]] = linalg.tensor_collapse_shape %[[ARG0]] []
 //  CHECK-SAME:     tensor<1x1xf32> into tensor<f32>
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
@@ -251,7 +251,7 @@ func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32>
     ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
       linalg.yield %arg1 : f32
     } -> tensor<1x2x5xf32>
-  %3 = linalg.tensor_reshape %2 [[0, 1], [2]]
+  %3 = linalg.tensor_collapse_shape %2 [[0, 1], [2]]
     : tensor<1x2x5xf32> into tensor<2x5xf32>
   return %3 : tensor<2x5xf32>
 }
@@ -283,7 +283,7 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
 //       CHECK: func @fold_unit_dim_for_init_tensor
 
 
-//       CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32>
+//       CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32>
 //       CHECK: %[[INIT:.+]] = linalg.init_tensor [] : tensor<f32>
 //       CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<f32>, f32 -> tensor<f32>
 //       CHECK: %[[GENERIC:.+]] = linalg.generic
@@ -291,7 +291,7 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
 //  CHECK-SAME:     iterator_types = ["reduction"]
 //  CHECK-SAME:   ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>)
 //  CHECK-SAME:   outs(%[[FILL]] : tensor<f32>)
-//       CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_reshape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
+//       CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_expand_shape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
 //       CHECK: return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32>
 
 
@@ -314,11 +314,11 @@ func @fold_subtensor(
 // CHECK-SAME:   %[[ARG1:.+]]: tensor<1x?x?x?x?x1x1xf32>
 //      CHECK:   %[[SUBTENSOR1:.+]] = subtensor %[[ARG0]]
 // CHECK-SAME:       to tensor<?x?x?xf32>
-//      CHECK:   %[[RESULT1:.+]] = linalg.tensor_reshape %[[SUBTENSOR1]]
+//      CHECK:   %[[RESULT1:.+]] = linalg.tensor_expand_shape %[[SUBTENSOR1]]
 // CHECK-SAME:       [0, 1], [2], [3, 4, 5, 6]
 //      CHECK:   %[[SUBTENSOR2:.+]] = subtensor %[[ARG1]]
 // CHECK-SAME:       to tensor<?x?x?xf32>
-//      CHECK:   %[[RESULT2:.+]] = linalg.tensor_reshape %[[SUBTENSOR2]]
+//      CHECK:   %[[RESULT2:.+]] = linalg.tensor_expand_shape %[[SUBTENSOR2]]
 // CHECK-SAME:       [0, 1], [2], [3, 4, 5, 6]
 //      CHECK:   return %[[RESULT1]], %[[RESULT2]]
 
@@ -346,7 +346,7 @@ func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
 //  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
 //      CHECK: func @unit_dim_for_reduction
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x?xf32>
-//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor<?xf32>
 //      CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
 //      CHECK:   %[[RESULT:.+]] = linalg.generic
@@ -354,7 +354,7 @@ func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
 // CHECK-SAME:     iterator_types = ["parallel", "reduction"]
 // CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x?xf32>)
 // CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] {{\[}}[0, 1]]
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
 //      CHECK:   return %[[RESULT_RESHAPE]]
 
 // -----
@@ -380,7 +380,7 @@ func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1x
 //  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
 //      CHECK: func @unit_dim_for_reduction_keep_one
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x1xf32>
-//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32>
 //      CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
 //      CHECK:   %[[RESULT:.+]] = linalg.generic
@@ -388,7 +388,7 @@ func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1x
 // CHECK-SAME:     iterator_types = ["parallel", "reduction"]
 // CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x1xf32>)
 // CHECK-SAME:     outs(%[[FILL]] : tensor<1xf32>)
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] {{\[}}[0, 1]]
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
 //      CHECK:   return %[[RESULT_RESHAPE]]
 
 // -----
@@ -415,7 +415,7 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
 //  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
 //      CHECK: func @unit_dim_for_reduction_inner
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<?x1x?x1xf32>
-//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] {{\[}}[0, 1], [2, 3]]
+//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]]
 //      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor<?xf32>
 //      CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
 //      CHECK:   %[[RESULT:.+]] = linalg.generic
@@ -423,7 +423,7 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
 // CHECK-SAME:     iterator_types = ["parallel", "reduction"]
 // CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x?xf32>)
 // CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] {{\[}}[0, 1]]
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
 //      CHECK:   return %[[RESULT_RESHAPE]]
 
 // -----
@@ -435,7 +435,7 @@ func @subtensor_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> {
 // CHECK-LABEL: func @subtensor_unit_dims
 //       CHECK:   %[[SUBTENSOR:.+]] = subtensor
 //  CHECK-SAME:     tensor<1x3xf32> to tensor<f32>
-//       CHECK:   %[[RESULT:.+]] = linalg.tensor_reshape %[[SUBTENSOR]] []
+//       CHECK:   %[[RESULT:.+]] = linalg.tensor_expand_shape %[[SUBTENSOR]] []
 //       CHECK:   return %[[RESULT]]
 
 // -----
@@ -445,7 +445,7 @@ func @subtensor_insert_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>)
   return %0 : tensor<1x3xf32>
 }
 // CHECK-LABEL: func @subtensor_insert_unit_dims
-//       CHECK:   %[[RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} []
+//       CHECK:   %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %{{.+}} []
 //       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[RESHAPE]]
 //  CHECK-SAME:     tensor<f32> into tensor<1x3xf32>
 //       CHECK:   return %[[RESULT]]
diff  --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index 5f0ec829978d1..822663ab7006f 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -5,14 +5,14 @@
 
 // CHECK-LABEL: func @reshape
 // CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>)
-//      CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
+//      CHECK: %[[RI:.*]] = linalg.tensor_collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
 //      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
-//      CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] {{\[}}[0, 1], [2]] : tensor<?x16xf32> into tensor<?x112x16xf32>
+//      CHECK: %[[RR:.*]] = linalg.tensor_expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<?x16xf32> into tensor<?x112x16xf32>
 //      CHECK: return %[[RR]] : tensor<?x112x16xf32>
 func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>) -> tensor<?x112x16xf32> {
-  %0 = linalg.tensor_reshape %A [[0, 1], [2]]
+  %0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
       : tensor<?x16xf32> into tensor<?x112x16xf32>
   %2 = linalg.generic {indexing_maps = [
     affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
@@ -35,17 +35,17 @@ func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf
 // CHECK-LABEL: func @reshape_multiple
 // CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>)
 //      CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
-//      CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
+//      CHECK: %[[RI:.*]] = linalg.tensor_collapse_shape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
 //      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
-//      CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
+//      CHECK: %[[RR:.*]] = linalg.tensor_expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
 //      CHECK: return %[[RR]] : tensor<112x112x16xf32>
 func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
   %C: tensor<16xf32>) -> tensor<112x112x16xf32> {
-  %0 = linalg.tensor_reshape %A [[0, 1], [2]]
+  %0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
       : tensor<12544x16xf32> into tensor<112x112x16xf32>
-  %1 = linalg.tensor_reshape %B [[0, 1], [2]]
+  %1 = linalg.tensor_expand_shape %B [[0, 1], [2]]
       : tensor<12544x16xf32> into tensor<112x112x16xf32>
   %2 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
   %3 = linalg.generic {indexing_maps = [
@@ -69,11 +69,11 @@ func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
 // Negative test, since the second source is broadcasted from d1 we cannot merge
 // d0 and d1 dimensions
 // CHECK-LABEL: func @reshape_negative
-// CHECK: linalg.tensor_reshape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
+// CHECK: linalg.tensor_expand_shape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
 // CHECK: linalg.generic
 // CHECK: } -> tensor<112x112x16xf32>
 func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
-  %20 = linalg.tensor_reshape %A [[0, 1], [2]]
+  %20 = linalg.tensor_expand_shape %A [[0, 1], [2]]
       : tensor<12544x16xf32> into tensor<112x112x16xf32>
   %21 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
   %22 = linalg.generic {indexing_maps = [
@@ -96,7 +96,7 @@ func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
   %cst_6 = constant 1.000000e+00 : f32
   %cst_7 = constant 7.000000e+00 : f32
   %cst_8 = constant 1.1920929E-7 : f32
-  %25 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
+  %25 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
       : tensor<6x5xi32> into tensor<2x3x5xi32>
   %26 = linalg.init_tensor [2, 3, 5] : tensor<2x3x5xf32>
   %28 = linalg.generic {
@@ -122,5 +122,5 @@ func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
 //       CHECK:   %[[OP:.+]] = linalg.generic
 //  CHECK-SAME:   ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<6x5xi32>, tensor<5xf32>, tensor<5xf32>)
 //  CHECK-SAME:   outs(%{{.+}} : tensor<6x5xf32>)
-//       CHECK:   linalg.tensor_reshape %[[OP]]
+//       CHECK:   linalg.tensor_expand_shape %[[OP]]
 //  CHECK-SAME:   tensor<6x5xf32> into tensor<2x3x5xf32>
diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 895799629ddf3..9dc20a46dde0a 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -348,21 +348,35 @@ func @generic(%arg0: memref<?x?xi4>) {
 
 func @reshape(%arg0: memref<f32>) {
   // expected-error @+1 {{expected non-zero memref ranks}}
-  %0 = linalg.reshape %arg0 [[0]] : memref<f32> into memref<f32>
+  %0 = linalg.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
+}
+
+// -----
+
+func @collapse_to_higher_rank(%arg0: memref<f32>) {
+  // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
+  %0 = linalg.collapse_shape %arg0 [[0]] : memref<f32> into memref<1xf32>
+}
+
+// -----
+
+func @expand_to_smaller_rank(%arg0: memref<1xf32>) {
+  // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
+  %0 = linalg.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
 }
 
 // -----
 
 func @reshape(%arg0: memref<?xf32>) {
   // expected-error @+1 {{expected to collapse or expand dims}}
-  %0 = linalg.reshape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
+  %0 = linalg.collapse_shape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
 }
 
 // -----
 
 func @reshape(%arg0: memref<?x?x?xf32>) {
   // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
-  %0 = linalg.reshape %arg0 [[0, 1]] :
+  %0 = linalg.collapse_shape %arg0 [[0, 1]] :
     memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
 }
 
@@ -370,7 +384,7 @@ func @reshape(%arg0: memref<?x?x?xf32>) {
 
 func @reshape(%arg0: memref<?x?x?xf32>) {
   // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}}
-  %0 = linalg.reshape %arg0 [[0, 1], [1, 2]] :
+  %0 = linalg.collapse_shape %arg0 [[0, 1], [1, 2]] :
     memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
 }
 
@@ -378,7 +392,7 @@ func @reshape(%arg0: memref<?x?x?xf32>) {
 
 func @reshape(%arg0: memref<?x?x?xf32>) {
   // expected-error @+1 {{expected collapsed type to be 'memref<?x?xf32>', but got 'memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>'}}
-  %0 = linalg.reshape %arg0 [[0, 1], [2]] :
+  %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
     memref<?x?x?xf32> into memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
 }
 
@@ -455,7 +469,7 @@ func @illegal_expanding_reshape_dynamic_tensor
   (%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?x4x?xf32>
 {
   // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
-  %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3, 4]]
       : tensor<?x?x?xf32> into tensor<?x?x?x4x?xf32>
   return %0 : tensor<?x?x?x4x?xf32>
 }
@@ -466,7 +480,7 @@ func @illegal_expanding_reshape_dynamic_memref
   (%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32>
 {
   // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
-  %0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]]
+  %0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]]
       : memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
   return %0 : memref<?x?x?x4x?xf32>
 }
@@ -477,7 +491,7 @@ func @illegal_expanding_reshape_static_tensor
   (%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32>
 {
   // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
-  %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3, 4]]
       : tensor<2x3x20xf32> into tensor<2x3x2x4x5xf32>
   return %0 : tensor<2x3x2x4x5xf32>
 }
@@ -488,7 +502,7 @@ func @illegal_collapsing_reshape_static_tensor
   (%arg0: tensor<2x3x2x4x5xf32>) -> tensor<2x3x20xf32>
 {
   // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
-  %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0], [1], [2, 3, 4]]
       : tensor<2x3x2x4x5xf32> into tensor<2x3x20xf32>
   return %0 : tensor<2x3x20xf32>
 }
@@ -499,7 +513,7 @@ func @illegal_expanding_reshape_static_memref
   (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32>
 {
   // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
-  %0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]]
+  %0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]]
       : memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
   return %0 : memref<2x3x2x4x5xf32>
 }
@@ -510,87 +524,87 @@ func @illegal_collapsing_reshape_static_memref
   (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32>
 {
   // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
-  %0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]]
+  %0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]]
       : memref<2x3x2x4x5xf32> into memref<2x3x20xf32>
   return %0 : memref<2x3x20xf32>
 }
 
 // -----
 
-func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
+func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
 {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
       : tensor<?x?xf32> into tensor<?x4x5xf32>
   return %0 : tensor<?x4x5xf32>
 }
 
 // -----
 
-func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
+func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
 {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]]
       : tensor<?x?xf32> into tensor<?x4x5xf32>
   return %0 : tensor<?x4x5xf32>
 }
 
 // -----
 
-func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32>
+func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32>
 {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x4x5xf32> into tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
 // -----
 
-func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32>
+func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32>
 {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2]]
       : tensor<?x4x5xf32> into tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
 // -----
 
-func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
+func @illegal_expanding_reshape_mixed_memref(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
 {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
-  %0 = linalg.reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x5xf32>
   return %0 : memref<?x4x5xf32>
 }
 
 // -----
 
-func @illegal_collapsing_reshape_mixed_memref_2(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
+func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
 {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
-  %0 = linalg.reshape %arg0 [[0], [1, 2]]
+  %0 = linalg.expand_shape %arg0 [[0], [1, 2]]
       : memref<?x?xf32> into memref<?x4x5xf32>
   return %0 : memref<?x4x5xf32>
 }
 
 // -----
 
-func @illegal_expanding_reshape_mixed_memref(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
+func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
 {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
-  %0 = linalg.reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.collapse_shape %arg0 [[0, 1], [2]]
       : memref<?x4x5xf32> into memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }
 
 // -----
 
-func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
+func @illegal_collapsing_reshape_mixed_memref_2(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
 {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
-  %0 = linalg.reshape %arg0 [[0], [1, 2]]
+  %0 = linalg.collapse_shape %arg0 [[0], [1, 2]]
       : memref<?x4x5xf32> into memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }
diff  --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index ea57d8ef299b9..4ddee115c0b03 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -14,13 +14,13 @@ func @range(%arg0: index) {
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)>
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)>
 
-func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
   // Reshapes that expand a contiguous tensor with some 1's.
-  %0 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]]
+  %0 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]]
       : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
   return %0 : memref<1x3x4x1x5xf32>
 }
-// CHECK-LABEL: func @reshape_static_expand
+// CHECK-LABEL: func @expand_shape_static
 //       CHECK:    llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
 //       CHECK:    llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
@@ -49,12 +49,12 @@ func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
 //       CHECK:    llvm.mlir.constant(1 : index) : i64
 //       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
 
-func @reshape_static_collapse(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
-  %0 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]] :
+func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+  %0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
   return %0 : memref<3x4x5xf32>
 }
-// CHECK-LABEL: func @reshape_static_collapse
+// CHECK-LABEL: func @collapse_shape_static
 //       CHECK:    llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:    llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
 //       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
@@ -75,11 +75,11 @@ func @reshape_static_collapse(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32>
 //       CHECK:    llvm.mlir.constant(1 : index) : i64
 //       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
 
-func @reshape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
-  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
+func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
+  %0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
   return %0 : memref<f32>
 }
-// CHECK-LABEL: func @reshape_fold_zero_dim
+// CHECK-LABEL: func @collapse_shape_fold_zero_dim
 //       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
 //       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
@@ -88,11 +88,11 @@ func @reshape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
 //       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
 
-func @reshape_expand_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
-  %0 = linalg.reshape %arg0 [] : memref<f32> into memref<1x1xf32>
+func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
+  %0 = linalg.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
   return %0 : memref<1x1xf32>
 }
-// CHECK-LABEL: func @reshape_expand_zero_dim
+// CHECK-LABEL: func @expand_shape_zero_dim
 //       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
 //       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 576c83a32dab5..4a3d54c3dc3a4 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -6,7 +6,7 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
                                          %arg1 : tensor<?x?x?xf32>) ->
                                          tensor<?x?x?xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
+  %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]] :
     tensor<?x?x4x?xf32> into tensor<?x?x?xf32>
   %1 = linalg.generic {
      indexing_maps = [#map0, #map1, #map1],
@@ -25,16 +25,16 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
 //      CHECK: func @generic_op_reshape_producer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//      CHECK:   %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
 // CHECK-SAME:     [0], [1, 2], [3]
-//      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+//      CHECK:   %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]]
 // CHECK-SAME:     [0], [1], [2, 3]
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>)
 // CHECK-SAME:     outs(%{{.+}} : tensor<?x?x?x4xf32>)
-//      CHECK:   %[[T4:.+]] = linalg.tensor_reshape %[[T3]]
+//      CHECK:   %[[T4:.+]] = linalg.tensor_collapse_shape %[[T3]]
 // CHECK-SAME:     [0], [1], [2, 3]
 // CHECK-SAME:     tensor<?x?x?x4xf32> into tensor<?x?x?xf32>
 //      CHECK:   return %[[T4]]
@@ -55,19 +55,19 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
       %1 = mulf %arg3, %arg4 : f32
       linalg.yield %1 : f32
   } -> tensor<?x?xf32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
+  %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] :
     tensor<?x?xf32> into tensor<?x4x?x5xf32>
   return %1 : tensor<?x4x?x5xf32>
 }
 
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //      CHECK: func @generic_op_reshape_consumer_fusion
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>)
+//      CHECK:   %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
 // CHECK-SAME:     [0], [1, 2, 3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
-//      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+//      CHECK:   %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]]
 // CHECK-SAME:     [0], [1, 2, 3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
@@ -94,7 +94,7 @@ func @reshape_as_consumer_permutation
          %1 = addf %arg0, %arg1 : f32
          linalg.yield %1 : f32
        } -> tensor<?x?x?xf32>
-  %d = linalg.tensor_reshape %c [[0, 1], [2], [3, 4, 5]]
+  %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]]
        : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
   return %d : tensor<?x2x?x3x4x?xf32>
 }
@@ -104,10 +104,10 @@ func @reshape_as_consumer_permutation
 //      CHECK: func @reshape_as_consumer_permutation
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//      CHECK:   %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
 // CHECK-SAME:     [0, 1, 2], [3, 4], [5]
 // CHECK-SAME:     tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
-//      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+//      CHECK:   %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]]
 // CHECK-SAME:     [0, 1, 2], [3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<3x4x?x?xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
@@ -136,7 +136,7 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
       %2 = mulf %arg1, %arg2 : f32
       linalg.yield %2 : f32
     } -> tensor<264x4xf32>
-  %2 = linalg.tensor_reshape %1 [[0, 1], [2]] :
+  %2 = linalg.tensor_expand_shape %1 [[0, 1], [2]] :
     tensor<264x4xf32> into tensor<8x33x4xf32>
   return %2 : tensor<8x33x4xf32>
 }
@@ -144,7 +144,7 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK: func @generic_op_reshape_consumer_static
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
-//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//      CHECK:   %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
 // CHECK-SAME:     [0, 1], [2]
 // CHECK-SAME:     tensor<264x4xf32> into tensor<8x33x4xf32>
 //      CHECK:   %[[T1:.+]] = linalg.init_tensor [8, 33, 4]
@@ -163,7 +163,7 @@ func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
                                          %arg1 : tensor<?x?x?xi32>) ->
                                          tensor<?x?x?xi32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]]:
+  %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3]]:
     tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
   %1 = linalg.generic {
      indexing_maps = [#map0, #map1, #map1],
@@ -229,7 +229,7 @@ func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
       %5 = addi %3, %4 : i32
       linalg.yield %5 : i32
   } -> tensor<?x?xi32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
+  %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] :
     tensor<?x?xi32> into tensor<?x?x4x5xi32>
   return %1 : tensor<?x?x4x5xi32>
 }
@@ -279,7 +279,7 @@ func @reshape_as_consumer_permutation
          %7 = addi %5, %6 : i32
          linalg.yield %7 : i32
        } -> tensor<6x4x210xi32>
-  %d = linalg.tensor_reshape %c [[0, 1], [2], [3, 4, 5]]
+  %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]]
        : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
   return %d : tensor<2x3x4x5x6x7xi32>
 }
@@ -293,9 +293,9 @@ func @reshape_as_consumer_permutation
 //       CHECK: func @reshape_as_consumer_permutation
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<210x6x4xi32>
 //  CHECK-SAME:   %[[ARG1:.+]]: tensor<210x4xi32>
-//   CHECK-DAG:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
+//   CHECK-DAG:   %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG0]]
 //  CHECK-SAME:     [0, 1, 2], [3, 4], [5]
-//   CHECK-DAG:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
+//   CHECK-DAG:   %[[T2:.+]] = linalg.tensor_expand_shape %[[ARG1]]
 //  CHECK-SAME:     [0, 1, 2], [3]
 //   CHECK-DAG:   %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
 //       CHECK:   %[[T4:.+]] = linalg.generic
@@ -326,7 +326,7 @@ func @reshape_as_consumer_permutation
 func @reshape_as_producer_projected_permutation(
     %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]]
     : tensor<33x8x?xi32> into tensor<264x?xi32>
   %1 = linalg.generic
     {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
@@ -372,7 +372,7 @@ func @reshape_as_producer_projected_permutation(
 //       CHECK:       %[[T5:.+]] = index_cast %[[IDX3]] : index to i32
 //       CHECK:       %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
 //       CHECK:       linalg.yield %[[T6]] : i32
-//       CHECK:    %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
+//       CHECK:    %[[RES2:.+]] = linalg.tensor_collapse_shape %[[RES]]
 //  CHECK-SAME:      [0, 1], [2], [3]
 //  CHECK-SAME:    : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
 //       CHECK:  return %[[RES2]] : tensor<264x?x4xi32>
@@ -394,7 +394,7 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
       %1 = mulf %arg3, %arg4 : f32
       linalg.yield %1 : f32
   } -> tensor<?x?xf32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
+  %1 = linalg.tensor_expand_shape %0 [[0], [1, 2, 3]] :
     tensor<?x?xf32> into tensor<?x?x4x5xf32>
   return %1 : tensor<?x?x4x5xf32>
 }
@@ -404,10 +404,10 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 //      CHECK: func @generic_op_reshape_consumer_fusion_projected
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//      CHECK:   %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
 // CHECK-SAME:     [0, 1, 2], [3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
-//      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+//      CHECK:   %[[T1:.+]] = linalg.tensor_expand_shape %[[ARG1]]
 // CHECK-SAME:     [0, 1, 2], [3]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
@@ -420,7 +420,7 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 // -----
 
 func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
-  %0 = linalg.tensor_reshape %arg0 [[0, 1]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1]]
       : tensor<1x5xf32> into tensor<5xf32>
   %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
   %2 = linalg.generic
@@ -434,7 +434,7 @@ func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
   return %2 : tensor<5x5xf32>
 }
 //      CHECK: func @unit_dim_reshape_expansion
-//  CHECK-DAG:   linalg.tensor_reshape
+//  CHECK-DAG:   linalg.tensor_collapse_shape
 //  CHECK-DAG:   linalg.init_tensor
 //      CHECK:   linalg.generic
 
@@ -450,14 +450,14 @@ func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
   ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
     linalg.yield %arg2 : f32
   } -> tensor<5x5xf32>
-  %2 = linalg.tensor_reshape %1 [[0, 1], [2]]
+  %2 = linalg.tensor_expand_shape %1 [[0, 1], [2]]
     : tensor<5x5xf32> into tensor<5x1x5xf32>
   return %2 : tensor<5x1x5xf32>
 }
 // CHECK: func @unit_dim_reshape_collapse
 // CHECK:   linalg.init_tensor
 // CHECK:   linalg.generic
-// CHECK:   linalg.tensor_reshape
+// CHECK:   linalg.tensor_expand_shape
 
 // -----
 
@@ -465,7 +465,7 @@ func @unit_dim_reshape_expansion_full
   (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
   -> tensor<?x2x4xf32> {
   %c1 = constant 1 : index
-  %0 = linalg.tensor_reshape %arg0 [[0, 1, 2], [3, 4], [5]]
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
     : tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
   %1 = memref.dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
   %2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
@@ -483,7 +483,7 @@ func @unit_dim_reshape_expansion_full
   return %3 : tensor<?x2x4xf32>
 }
 //      CHECK: func @unit_dim_reshape_expansion_full
-//  CHECK-DAG:   linalg.tensor_reshape
+//  CHECK-DAG:   linalg.tensor_collapse_shape
 //  CHECK-DAG:   linalg.init_tensor
 //      CHECK:   linalg.generic
 // CHECK-SAME:     ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
@@ -491,7 +491,7 @@ func @unit_dim_reshape_expansion_full
 //         FOLDUNITDIM: func @unit_dim_reshape_expansion_full
 //    FOLDUNITDIM-SAME:   %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32>
 //    FOLDUNITDIM-SAME:   %[[ARG1:.+]]: tensor<?x2x4xf32>
-//     FOLDUNITDIM-DAG:   %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG1]]
+//     FOLDUNITDIM-DAG:   %[[RESHAPE:.+]] = linalg.tensor_expand_shape %[[ARG1]]
 //         FOLDUNITDIM:   linalg.generic
 //    FOLDUNITDIM-SAME:     ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
 //    FOLDUNITDIM-SAME:     outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
diff  --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
index 15fc2b5de0ae2..e334b46d6b5d4 100644
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -3,7 +3,7 @@
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
   -> tensor<?x?x4x?xi32> {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]] :
     tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
   %1 = linalg.generic {
     indexing_maps = [#map0, #map0],
@@ -22,7 +22,7 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
 //   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //       CHECK: func @generic_op_reshape_producer_fusion
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xi32>
-//       CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//       CHECK:   %[[T0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
 //  CHECK-SAME:     [0], [1, 2], [3]
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP3]], #[[MAP4]]]
@@ -46,7 +46,7 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
     %3 = addi %arg6, %2 : i32
     linalg.yield %3 : i32
   } -> tensor<?x?x4x5xi32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
+  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] :
     tensor<?x?x4x5xi32> into tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }
@@ -54,21 +54,21 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
 //   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
 //       CHECK: func @generic_op_reshape_consumer_fusion
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
-//       CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//       CHECK:   %[[T0:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
 //  CHECK-SAME:     [0], [1, 2, 3]
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
 //  CHECK-SAME:     outs(%[[T0]] : tensor<?x?xi32>)
 //       CHECK:   %[[IDX:.+]] = linalg.index 0 : index
 //  CHECK-NEXT:   %[[IDX_CASTED:.+]] = index_cast %[[IDX]] : index to i32
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.tensor_collapse_shape
 
 // -----
 
 #map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]]
       : tensor<3x35xf32> into tensor<3x5x7xf32>
   %1 = linalg.init_tensor [3, 7, 5] : tensor<3x7x5xf32>
   %2 = linalg.generic
@@ -84,7 +84,7 @@ func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
 //   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
 //   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //       CHECK: func @generic_op_021_permultation_reshape_producer_fusion
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.tensor_expand_shape
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
 
@@ -93,7 +93,7 @@ func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
 #map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
 func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]]
       : tensor<3x35xf32> into tensor<3x5x7xf32>
   %1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32>
   %2 = linalg.generic
@@ -109,7 +109,7 @@ func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32
 //   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
 //   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
 //       CHECK: func @generic_op_120_permutation_reshape_producer_fusion
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.tensor_expand_shape
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
 
@@ -120,7 +120,7 @@ func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32
 #map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1, 2]]
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2]]
       : tensor<3x35xf32> into tensor<3x5x7xf32>
   %1 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32>
   %2 = linalg.generic
@@ -137,7 +137,7 @@ func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
 //   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
 //   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //       CHECK: func @generic_op_102_permultation_reshape_producer_fusion
-//   CHECK-NOT:   linalg.tensor_reshape
+//   CHECK-NOT:   linalg.tensor_expand_shape
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
 
@@ -156,7 +156,7 @@ func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf
     ^bb0(%arg2: f32, %arg3 : f32):  // no predecessors
       linalg.yield %arg2 : f32
   } -> tensor<5x3x7xf32>
-  %2 = linalg.tensor_reshape %1 [[0], [1, 2]]
+  %2 = linalg.tensor_collapse_shape %1 [[0], [1, 2]]
       : tensor<5x3x7xf32> into tensor<5x21xf32>
   return %2 : tensor<5x21xf32>
 }
@@ -165,7 +165,7 @@ func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf
 //       CHECK: func @generic_op_102_permultation_reshape_consumer_fusion
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<3x5x7xf32>
 //       CHECK:   %[[T0:.+]] = linalg.init_tensor [5, 3, 7]
-//       CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[T0]]
+//       CHECK:   %[[T1:.+]] = linalg.tensor_collapse_shape %[[T0]]
 //  CHECK-SAME:     [0], [1, 2]
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
@@ -188,7 +188,7 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
       %1 = mulf %arg3, %arg4 : f32
       linalg.yield %1 : f32
   } -> tensor<?x?x?x5xf32>
-  %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
+  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] :
     tensor<?x?x?x5xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
@@ -197,5 +197,5 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?x5xf32>
 //       CHECK:   %[[NOFUSE:.+]] = linalg.generic
 //  CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]]
-//       CHECK:   %[[RESULT:.+]] = linalg.tensor_reshape %[[NOFUSE]]
+//       CHECK:   %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[NOFUSE]]
 //       CHECK:   return %[[RESULT]]
diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 211e97855c662..9ca383a66cc13 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -563,92 +563,92 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
 func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>,
                      %arg2: tensor<3x?x5xf32>) {
   // Reshapes that collapse and expand back a contiguous buffer.
-  %0 = linalg.reshape %arg0 [[0, 1], [2]] :
+  %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
     memref<3x4x5xf32> into memref<12x5xf32>
-  %r0 = linalg.reshape %0 [[0, 1], [2]] :
+  %r0 = linalg.expand_shape %0 [[0, 1], [2]] :
     memref<12x5xf32> into memref<3x4x5xf32>
-  %1 = linalg.reshape %arg0 [[0], [1, 2]] :
+  %1 = linalg.collapse_shape %arg0 [[0], [1, 2]] :
     memref<3x4x5xf32> into memref<3x20xf32>
-  %r1 = linalg.reshape %1 [[0], [1, 2]] :
+  %r1 = linalg.expand_shape %1 [[0], [1, 2]] :
     memref<3x20xf32> into memref<3x4x5xf32>
-  %2 = linalg.reshape %arg0 [[0, 1, 2]] :
+  %2 = linalg.collapse_shape %arg0 [[0, 1, 2]] :
     memref<3x4x5xf32> into memref<60xf32>
-  %r2 = linalg.reshape %2 [[0, 1, 2]] :
+  %r2 = linalg.expand_shape %2 [[0, 1, 2]] :
     memref<60xf32> into memref<3x4x5xf32>
   // Reshapes that expand and collapse back a contiguous buffer with some 1's.
-  %3 = linalg.reshape %arg0 [[0, 1], [2], [3, 4]] :
+  %3 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-  %r3 = linalg.reshape %3 [[0, 1], [2], [3, 4]] :
+  %r3 = linalg.collapse_shape %3 [[0, 1], [2], [3, 4]] :
     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
   // Reshapes on tensors.
-  %t0 = linalg.tensor_reshape %arg1 [[0, 1], [2], [3, 4]] :
+  %t0 = linalg.tensor_expand_shape %arg1 [[0, 1], [2], [3, 4]] :
     tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
-  %rt0 = linalg.tensor_reshape %t0 [[0, 1], [2], [3, 4]] :
+  %rt0 = linalg.tensor_collapse_shape %t0 [[0, 1], [2], [3, 4]] :
     tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
-  %t1 = linalg.tensor_reshape %arg2 [[0, 1], [2], [3, 4]] :
+  %t1 = linalg.tensor_expand_shape %arg2 [[0, 1], [2], [3, 4]] :
     tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
-  %rt1 = linalg.tensor_reshape %t1 [[0], [1, 2], [3, 4]] :
+  %rt1 = linalg.tensor_collapse_shape %t1 [[0], [1, 2], [3, 4]] :
     tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
   return
 }
 // CHECK-LABEL: func @reshape_static
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<3x4x5xf32> into memref<12x5xf32>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<12x5xf32> into memref<3x4x5xf32>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0], [1, 2]]
+//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
 //  CHECK-SAME:     memref<3x4x5xf32> into memref<3x20xf32>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0], [1, 2]]
+//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0], [1, 2]]
 //  CHECK-SAME:     memref<3x20xf32> into memref<3x4x5xf32>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1, 2]]
+//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
 //  CHECK-SAME:     memref<3x4x5xf32> into memref<60xf32>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1, 2]]
+//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1, 2]]
 //  CHECK-SAME:     memref<60xf32> into memref<3x4x5xf32>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
 //  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
 //  CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
 //
-//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
-//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
-//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
-//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
+//       CHECK:   linalg.tensor_expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+//       CHECK:   linalg.tensor_collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+//       CHECK:   linalg.tensor_expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+//       CHECK:   linalg.tensor_collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
 
 // -----
 
 func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
                       %arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
                       %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
-  %0 = linalg.reshape %arg0 [[0, 1], [2]] :
+  %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
     memref<?x?x?xf32> into memref<?x?xf32>
-  %r0 = linalg.reshape %0 [[0, 1], [2]] :
+  %r0 = linalg.expand_shape %0 [[0, 1], [2]] :
     memref<?x?xf32> into memref<?x4x?xf32>
-  %1 = linalg.reshape %arg1 [[0, 1], [2]] :
+  %1 = linalg.collapse_shape %arg1 [[0, 1], [2]] :
     memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
     memref<?x?xf32, offset : 0, strides : [?, 1]>
-  %r1 = linalg.reshape %1 [[0, 1], [2]] :
+  %r1 = linalg.expand_shape %1 [[0, 1], [2]] :
     memref<?x?xf32, offset : 0, strides : [?, 1]> into
     memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
-  %2 = linalg.reshape %arg2 [[0, 1], [2]] :
+  %2 = linalg.collapse_shape %arg2 [[0, 1], [2]] :
     memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
     memref<?x?xf32, offset : ?, strides : [?, 1]>
-  %r2 = linalg.reshape %2 [[0, 1], [2]] :
+  %r2 = linalg.expand_shape %2 [[0, 1], [2]] :
     memref<?x?xf32, offset : ?, strides : [?, 1]> into
     memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
   return
 }
 // CHECK-LABEL: func @reshape
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?xf32> into memref<?x4x?xf32>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
-//       CHECK:   linalg.reshape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
 
 func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
@@ -679,25 +679,25 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x
 
 func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (tensor<f32>, tensor<1x1xf32>)
 {
-  %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1xf32> into tensor<f32>
-  %1 = linalg.tensor_reshape %0 [] : tensor<f32> into tensor<1x1xf32>
+  %0 = linalg.tensor_collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
+  %1 = linalg.tensor_expand_shape %0 [] : tensor<f32> into tensor<1x1xf32>
   return %0, %1 : tensor<f32>, tensor<1x1xf32>
 }
 // CHECK-LABEL: func @tensor_reshape_zero_dim
-//       CHECK:   linalg.tensor_reshape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
-//       CHECK:   linalg.tensor_reshape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+//       CHECK:   linalg.tensor_collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
+//       CHECK:   linalg.tensor_expand_shape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
 
 // -----
 
 func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
 {
-  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
-  %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
+  %0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
+  %1 = linalg.expand_shape %0 [] : memref<f32> into memref<1x1xf32>
   return %0, %1 : memref<f32>, memref<1x1xf32>
 }
 // CHECK-LABEL: func @memref_reshape_zero_dim
-//       CHECK:   linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
-//       CHECK:   linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+//       CHECK:   linalg.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
+//       CHECK:   linalg.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
 
 // -----
 
@@ -716,12 +716,12 @@ func @init_tensor(%arg0 : index, %arg1 : index)
 func @legal_collapsing_reshape_dynamic_tensor
   (%arg0: tensor<?x?x?x4x?xf32>) -> tensor<?x?x?xf32>
 {
-  %0 = linalg.tensor_reshape %arg0 [[0], [1], [2, 3, 4]] :
+  %0 = linalg.tensor_collapse_shape %arg0 [[0], [1], [2, 3, 4]] :
     tensor<?x?x?x4x?xf32> into tensor<?x?x?xf32>
   return %0 : tensor<?x?x?xf32>
 }
 //      CHECK: func @legal_collapsing_reshape_dynamic_tensor
-//      CHECK:   linalg.tensor_reshape
+//      CHECK:   linalg.tensor_collapse_shape
 // CHECK-SAME:    [0], [1], [2, 3, 4]
 
 // -----
@@ -729,12 +729,12 @@ func @legal_collapsing_reshape_dynamic_tensor
 func @legal_collapsing_reshape_dynamic_memref
   (%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32>
 {
-  %0 = linalg.reshape %arg0 [[0], [1], [2, 3, 4]] :
+  %0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]] :
     memref<?x?x?x4x?xf32> into memref<?x?x?xf32>
   return %0 : memref<?x?x?xf32>
 }
 //      CHECK: func @legal_collapsing_reshape_dynamic_memref
-//      CHECK:   linalg.reshape
+//      CHECK:   linalg.collapse_shape
 // CHECK-SAME:    [0], [1], [2, 3, 4]
 
 // -----
        
    
    
More information about the Mlir-commits
mailing list