[Mlir-commits] [mlir] 46ef86b - [mlir] Move linalg::Expand/CollapseShapeOp to memref dialect.

Alexander Belyaev llvmlistbot at llvm.org
Fri Jul 16 04:32:50 PDT 2021


Author: Alexander Belyaev
Date: 2021-07-16T13:32:17+02:00
New Revision: 46ef86b5d82ea8ec36de680f10b69d36487e4d1d

URL: https://github.com/llvm/llvm-project/commit/46ef86b5d82ea8ec36de680f10b69d36487e4d1d
DIFF: https://github.com/llvm/llvm-project/commit/46ef86b5d82ea8ec36de680f10b69d36487e4d1d.diff

LOG: [mlir] Move linalg::Expand/CollapseShapeOp to memref dialect.

RFC: https://llvm.discourse.group/t/rfc-reshape-ops-restructuring/3310

Differential Revision: https://reviews.llvm.org/D106141

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Dialect/Linalg/bufferize.mlir
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/llvm.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/MemRef/ops.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index cb0507d6f903..8d14880c148b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -396,91 +396,6 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
 def IndexListArrayAttr :
   TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
 
-class Linalg_ReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<mnemonic,
-    [DeclareOpInterfaceMethods<ViewLikeOpInterface>]>,
-    Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
-    Results<(outs AnyStridedMemRef:$result)> {
-  let extraClassDeclaration = commonExtraClassDeclaration # [{
-    MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
-    MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
-  }];
-  let hasFolder = 1;
-  let hasCanonicalizer = 1;
-  let printer = [{ return ::print(p, *this); }];
-}
-
-def Linalg_ExpandShapeOp : Linalg_ReshapeOp<"expand_shape"> {
-  let summary = "operation to produce a memref with a higher rank.";
-  let description = [{
-    The `linalg.expand_shape` op produces a new view with a higher rank whose
-    sizes are a reassociation of the original `view`. Depending on whether or
-    not the reassociated MemRefType is contiguous, the resulting memref may
-    require explicit alloc and copies.
-
-    A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of I64ArrayAttr attribute.
-
-    For now, it is assumed that either:
-      1. a reassociation produces and consumes contiguous MemRefType or,
-      2. the reshape op will be folded into its consumers (by changing the shape
-         of the computations).
-    All other cases are undefined behavior and a reshape op may not lower to
-    LLVM if it cannot be proven statically that it does not require alloc+copy.
-
-    The operand memref type when dimensions can be zero-ranked if the result
-    memref type is statically shaped with all dimensions being unit extent. In
-    such case the reassociation map is empty.
-
-    The verification rule is that the reassociation maps are applied to the
-    result memref with the larger rank to obtain the operand memref with the
-    smaller rank.
-
-    Example:
-
-    ```mlir
-    // Dimension expansion i -> (i', j') and (k) -> (k')
-    %1 = linalg.expand_shape %0 [[0, 1], [2]] :
-      memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
-    ```
-  }];
-}
-
-def Linalg_CollapseShapeOp : Linalg_ReshapeOp<"collapse_shape"> {
-  let summary = "operation to produce a memref with a smaller rank.";
-  let description = [{
-    The `linalg.collapse_shape` op produces a new view with a smaller rank
-    whose sizes are a reassociation of the original `view`. Depending on
-    whether or not the reassociated MemRefType is contiguous, the resulting
-    memref may require explicit alloc and copies.
-
-    A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of I64ArrayAttr attribute.
-
-    For now, it is assumed that either:
-      1. a reassociation produces and consumes contiguous MemRefType or,
-      2. the reshape op will be folded into its consumers (by changing the shape
-         of the computations).
-    All other cases are undefined behavior and a reshape op may not lower to
-    LLVM if it cannot be proven statically that it does not require alloc+copy.
-
-    The result memref type of a reshape can be zero-ranked if the operand
-    memref type is statically shaped with all dimensions being unit extent. In
-    such case the reassociation map is empty.
-
-    The verification rule is that the reassociation maps are applied to the
-    operand memref with the larger rank to obtain the result memref with the
-    smaller rank.
-
-    Examples:
-
-    ```mlir
-    // Dimension collapse (i, j) -> i' and k -> k'
-    %1 = linalg.collapse_shape %0 [[0, 1], [2]] :
-      memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
-    ```
-  }];
-}
-
 class Linalg_TensorReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<
     mnemonic,
     [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index c0f3c3445823..8c61248567f6 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_MEMREF_IR_MEMREF_H_
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index e60cf298d36e..3385a7e8e78c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -203,7 +203,7 @@ def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
 // AllocaScopeOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope", 
+def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
       [AutomaticAllocationScope,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
        SingleBlockImplicitTerminator<"AllocaScopeReturnOp">,
@@ -225,8 +225,8 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
 
     Here, `%myalloca` memref is valid within the explicitly delimited scope
     and is automatically deallocated at the end of the given region. Conceptually,
-    `memref.alloca_scope` is a passthrough operation with 
-    `AutomaticAllocationScope` that spans the body of the region within the operation. 
+    `memref.alloca_scope` is a passthrough operation with
+    `AutomaticAllocationScope` that spans the body of the region within the operation.
 
     `memref.alloca_scope` may also return results that are defined in the nested
     region. To return a value, one should use `memref.alloca_scope.return`
@@ -251,14 +251,14 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
 // AllocaScopeReturnOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return", 
+def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
       [HasParent<"AllocaScopeOp">,
        NoSideEffect,
        ReturnLike,
        Terminator]> {
   let summary = "terminator for alloca_scope operation";
   let description = [{
-    `memref.alloca_scope.return` operation returns zero or more SSA values 
+    `memref.alloca_scope.return` operation returns zero or more SSA values
     from the region within `memref.alloca_scope`. If no values are returned,
     the return operation may be omitted. Otherwise, it has to be present
     to indicate which values are going to be returned. For example:
@@ -927,6 +927,150 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ExpandShapeOp / CollapseShapeOp
+//===----------------------------------------------------------------------===//
+
+def IndexListArrayAttr :
+  TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
+
+class MemRef_ReassociativeReshapeOp<string mnemonic, list<OpTrait> traits = []> :
+    MemRef_Op<mnemonic, !listconcat(traits,
+      [NoSideEffect, ViewLikeOpInterface])>,
+    Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
+    Results<(outs AnyStridedMemRef:$result)>{
+  let builders = [
+    // Builders for a contracting reshape whose result type is computed from
+    // `src` and `reassociation`.
+    OpBuilder<(ins "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    OpBuilder<(ins "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, src, reassociationMaps, attrs);
+    }]>,
+
+    // Builders for a reshape whose result type is passed explicitly. This may
+    // be either a contracting or expanding reshape.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      build($_builder, $_state, resultType, src, attrs);
+      $_state.addAttribute("reassociation",
+                          getReassociationIndicesAttribute($_builder, reassociation));
+    }]>,
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+    [{
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+    }]>
+  ];
+
+  code commonExtraClassDeclaration = [{
+    SmallVector<AffineMap, 4> getReassociationMaps();
+    SmallVector<ReassociationExprs, 4> getReassociationExprs();
+    SmallVector<ReassociationIndices, 4> getReassociationIndices() {
+      SmallVector<ReassociationIndices, 4> reassociationIndices;
+      for (auto attr : reassociation())
+        reassociationIndices.push_back(llvm::to_vector<2>(
+            llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
+              return indexAttr.cast<IntegerAttr>().getInt();
+            })));
+      return reassociationIndices;
+    };
+    MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
+    MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
+    Value getViewSource() { return src(); }
+  }];
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
+}
+
+def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
+  let summary = "operation to produce a memref with a higher rank.";
+  let description = [{
+    The `memref.expand_shape` op produces a new view with a higher rank whose
+    sizes are a reassociation of the original `view`. Depending on whether or
+    not the reassociated MemRefType is contiguous, the resulting memref may
+    require explicit alloc and copies.
+
+    A reassociation is defined as a contiguous grouping of dimensions and is
+    represented with an array of I64ArrayAttr attributes.
+
+    For now, it is assumed that either:
+      1. a reassociation produces and consumes contiguous MemRefType or,
+      2. the reshape op will be folded into its consumers (by changing the shape
+         of the computations).
+    All other cases are undefined behavior and a reshape op may not lower to
+    LLVM if it cannot be proven statically that it does not require alloc+copy.
+
+    The operand memref type of a reshape can be zero-ranked if the result
+    memref type is statically shaped with all dimensions being unit extent. In
+    such a case the reassociation map is empty.
+
+    The verification rule is that the reassociation maps are applied to the
+    result memref with the larger rank to obtain the operand memref with the
+    smaller rank.
+
+    Example:
+
+    ```mlir
+    // Dimension expansion i -> (i', j') and (k) -> (k')
+    %1 = memref.expand_shape %0 [[0, 1], [2]] :
+      memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
+    ```
+  }];
+  let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
+def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
+  let summary = "operation to produce a memref with a smaller rank.";
+  let description = [{
+    The `memref.collapse_shape` op produces a new view with a smaller rank
+    whose sizes are a reassociation of the original `view`. Depending on
+    whether or not the reassociated MemRefType is contiguous, the resulting
+    memref may require explicit alloc and copies.
+
+    A reassociation is defined as a contiguous grouping of dimensions and is
+    represented with an array of I64ArrayAttr attributes.
+
+    For now, it is assumed that either:
+      1. a reassociation produces and consumes contiguous MemRefType or,
+      2. the reshape op will be folded into its consumers (by changing the shape
+         of the computations).
+    All other cases are undefined behavior and a reshape op may not lower to
+    LLVM if it cannot be proven statically that it does not require alloc+copy.
+
+    The result memref type of a reshape can be zero-ranked if the operand
+    memref type is statically shaped with all dimensions being unit extent. In
+    such a case the reassociation map is empty.
+
+    The verification rule is that the reassociation maps are applied to the
+    operand memref with the larger rank to obtain the result memref with the
+    smaller rank.
+
+    Examples:
+
+    ```mlir
+    // Dimension collapse (i, j) -> i' and k -> k'
+    %1 = memref.collapse_shape %0 [[0, 1], [2]] :
+      memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
+    ```
+  }];
+  let extraClassDeclaration = commonExtraClassDeclaration;
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3a557a63fd26..57350c5b1406 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -1,4 +1,4 @@
-//===- RehshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
+//===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -26,7 +26,7 @@ using ReassociationIndicesRef = ArrayRef<int64_t>;
 using ReassociationExprs = SmallVector<AffineExpr, 2>;
 
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
-constexpr StringRef getReassociationAttrName();
+constexpr StringRef getReassociationAttrName() { return "reassociation"; }
 
 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
@@ -45,6 +45,23 @@ Optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
     ArrayRef<ReassociationIndices> consumerReassociations,
     MLIRContext *context);
 
+/// Convert reassociation indices to affine expressions.
+SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
+    OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices);
+
+/// Constructs affine maps out of Array<Array<AffineExpr>>.
+SmallVector<AffineMap, 4>
+getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);
+
+/// Wraps a list of reassociations in an ArrayAttr.
+ArrayAttr
+getReassociationIndicesAttribute(OpBuilder &b,
+                                 ArrayRef<ReassociationIndices> reassociation);
+
+/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
+SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
+    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return llvm::None when this computation
 /// failed.
@@ -78,7 +95,7 @@ void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) {
 
   p << "] ";
   p.printOptionalAttrDict(op->getAttrs(),
-                          /*elidedAttrs=*/{op.getReassociationAttrName()});
+                          /*elidedAttrs=*/{getReassociationAttrName()});
   p << ": " << op.src().getType() << " into " << op.getType();
 }
 

diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 9357407f5521..580bc059e14f 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
@@ -93,48 +94,6 @@ class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
   }
 };
 
-// ReshapeOp creates a new view descriptor of the proper rank.
-// For now, the only conversion supported is for target MemRef with static sizes
-// and strides.
-template <typename ReshapeOp>
-class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
-public:
-  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
-  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
-
-  LogicalResult
-  matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    MemRefType dstType = reshapeOp.getResultType();
-
-    if (!dstType.hasStaticShape())
-      return failure();
-
-    int64_t offset;
-    SmallVector<int64_t, 4> strides;
-    auto res = getStridesAndOffset(dstType, strides, offset);
-    if (failed(res) || llvm::any_of(strides, [](int64_t val) {
-          return ShapedType::isDynamicStrideOrOffset(val);
-        }))
-      return failure();
-
-    ReshapeOpAdaptor adaptor(operands);
-    MemRefDescriptor baseDesc(adaptor.src());
-    Location loc = reshapeOp->getLoc();
-    auto desc =
-        MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
-                                this->typeConverter->convertType(dstType));
-    desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc));
-    desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc));
-    desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc));
-    for (auto en : llvm::enumerate(dstType.getShape()))
-      desc.setConstantSize(rewriter, loc, en.index(), en.value());
-    for (auto en : llvm::enumerate(strides))
-      desc.setConstantStride(rewriter, loc, en.index(), en.value());
-    rewriter.replaceOp(reshapeOp, {desc});
-    return success();
-  }
-};
 
 // YieldOp produces and LLVM::ReturnOp.
 class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
@@ -153,9 +112,7 @@ class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
 /// Populate the given list with patterns that convert from Linalg to LLVM.
 void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns) {
-  patterns.add<RangeOpConversion, ReshapeOpConversion<ExpandShapeOp>,
-               ReshapeOpConversion<CollapseShapeOp>, YieldOpConversion>(
-      converter);
+  patterns.add<RangeOpConversion, YieldOpConversion>(converter);
 
   // Populate the type conversions for the linalg types.
   converter.addConversion(
@@ -176,6 +133,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   LLVMTypeConverter converter(&getContext());
   populateLinalgToLLVMConversionPatterns(converter, patterns);
+  populateMemRefToLLVMConversionPatterns(converter, patterns);
 
   LLVMConversionTarget target(getContext());
   target.addIllegalOp<RangeOp>();

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 6f422e5f629f..f47b8878a17c 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -186,9 +186,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
                          StandardOpsDialect>();
-  target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
-  target.addLegalOp<linalg::ExpandShapeOp, linalg::CollapseShapeOp,
-                    linalg::RangeOp>();
+  target.addLegalOp<ModuleOp, FuncOp, ReturnOp, linalg::RangeOp>();
   RewritePatternSet patterns(&getContext());
   populateLinalgToStandardConversionPatterns(patterns);
   if (failed(applyFullConversion(module, target, std::move(patterns))))

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index a92c4069fcad..4d5fe3307a8a 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1000,6 +1000,49 @@ struct MemRefReshapeOpLowering
   }
 };
 
+// ReshapeOp creates a new view descriptor of the proper rank.
+// For now, the only conversion supported is for target MemRef with static sizes
+// and strides.
+template <typename ReshapeOp>
+class ReassociatingReshapeOpConversion
+    : public ConvertOpToLLVMPattern<ReshapeOp> {
+public:
+  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
+  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType dstType = reshapeOp.getResultType();
+
+    if (!dstType.hasStaticShape())
+      return failure();
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto res = getStridesAndOffset(dstType, strides, offset);
+    if (failed(res) || llvm::any_of(strides, [](int64_t val) {
+          return ShapedType::isDynamicStrideOrOffset(val);
+        }))
+      return failure();
+
+    ReshapeOpAdaptor adaptor(operands);
+    MemRefDescriptor baseDesc(adaptor.src());
+    Location loc = reshapeOp->getLoc();
+    auto desc =
+        MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
+                                this->typeConverter->convertType(dstType));
+    desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc));
+    desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc));
+    desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc));
+    for (auto en : llvm::enumerate(dstType.getShape()))
+      desc.setConstantSize(rewriter, loc, en.index(), en.value());
+    for (auto en : llvm::enumerate(strides))
+      desc.setConstantStride(rewriter, loc, en.index(), en.value());
+    rewriter.replaceOp(reshapeOp, {desc});
+    return success();
+  }
+};
 /// Conversion pattern that transforms a subview op into:
 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
@@ -1355,6 +1398,8 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
       MemRefReinterpretCastOpLowering,
       MemRefReshapeOpLowering,
       PrefetchOpLowering,
+      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
+      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
       StoreOpLowering,
       SubViewOpLowering,
       TransposeOpLowering,

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 12fbb8ce839b..b95fc5da547b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Matchers.h"
@@ -1103,14 +1104,6 @@ OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
 // ReshapeOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) {
-  ::mlir::printReshapeOp<linalg::ExpandShapeOp>(p, op);
-}
-
-static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) {
-  ::mlir::printReshapeOp<linalg::CollapseShapeOp>(p, op);
-}
-
 static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
   ::mlir::printReshapeOp<linalg::TensorExpandShapeOp>(p, op);
 }
@@ -1260,20 +1253,6 @@ convertReassociationIndicesToExprs(
   return reassociationMaps;
 }
 
-SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
-  return getSymbolLessAffineMaps(getReassociationExprs());
-}
-SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
-}
-SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
-  return getSymbolLessAffineMaps(getReassociationExprs());
-}
-SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
-}
 
 SmallVector<AffineMap, 4> TensorCollapseShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
@@ -1422,71 +1401,6 @@ getReassociationIndicesAttribute(OpBuilder &b,
   return b.getArrayAttr(reassociationAttr);
 }
 
-void mlir::linalg::ExpandShapeOp::build(
-    OpBuilder &b, OperationState &result, Value src,
-    ArrayRef<ReassociationIndices> reassociation,
-    ArrayRef<NamedAttribute> attrs) {
-  auto memRefType = src.getType().cast<MemRefType>();
-  auto resultType = computeReshapeCollapsedType(
-      memRefType, getSymbolLessAffineMaps(
-                      convertReassociationIndicesToExprs(b, reassociation)));
-  build(b, result, resultType, src, attrs);
-  result.addAttribute(getReassociationAttrName(),
-                      getReassociationIndicesAttribute(b, reassociation));
-}
-
-Value mlir::linalg::ExpandShapeOp::getViewSource() { return src(); }
-
-void mlir::linalg::CollapseShapeOp::build(
-    OpBuilder &b, OperationState &result, Value src,
-    ArrayRef<ReassociationIndices> reassociation,
-    ArrayRef<NamedAttribute> attrs) {
-  auto memRefType = src.getType().cast<MemRefType>();
-  auto resultType = computeReshapeCollapsedType(
-      memRefType, getSymbolLessAffineMaps(
-                      convertReassociationIndicesToExprs(b, reassociation)));
-  build(b, result, resultType, src, attrs);
-  result.addAttribute(getReassociationAttrName(),
-                      getReassociationIndicesAttribute(b, reassociation));
-}
-
-Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); }
-
-template <typename ReshapeOp,
-          bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
-static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
-                                     MemRefType collapsedType) {
-  if (failed(
-          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
-    return failure();
-  auto maps = op.getReassociationMaps();
-  MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
-  if (collapsedType != expectedType)
-    return op.emitOpError("expected collapsed type to be ")
-           << expectedType << ", but got " << collapsedType;
-  return success();
-}
-
-static LogicalResult verify(ExpandShapeOp op) {
-  return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
-}
-
-void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                MLIRContext *context) {
-  results.add<CollapseReshapeOps<ExpandShapeOp>,
-              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
-}
-
-static LogicalResult verify(CollapseShapeOp op) {
-  return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
-}
-
-void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                  MLIRContext *context) {
-  results.add<CollapseReshapeOps<CollapseShapeOp>,
-              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // TensorReshapeOp
 //===----------------------------------------------------------------------===//
@@ -2433,16 +2347,6 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
 // TODO: Consider making all this boilerplate easy to autogenerate
 // with Tablegen. This seems a desirable property in the context of
 // OpInterfaces where a Linalg "named" op **isa** LinalgOp.
-OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
-  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
-}
-OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
-  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
-}
 OpFoldResult TensorExpandShapeOp::fold(ArrayRef<Attribute> operands) {
   return foldReshapeOp<TensorExpandShapeOp, TensorCollapseShapeOp>(*this,
                                                                    operands);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index e4dcbde6f1da..b918b98a76b8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -155,8 +155,8 @@ class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
 public:
   using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
   using ReshapeOp = typename std::conditional_t<
-      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value, ExpandShapeOp,
-      CollapseShapeOp>;
+      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
+      memref::ExpandShapeOp, memref::CollapseShapeOp>;
 
   LogicalResult
   matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 865d4bf9227c..de847d1f0fe7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -352,7 +352,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
           convertAffineMapArrayToExprs(reassociationMap));
     }
     if (origResultType.isa<MemRefType>()) {
-      return rewriter.create<linalg::ExpandShapeOp>(
+      return rewriter.create<memref::ExpandShapeOp>(
           loc, origResultType, result,
           convertAffineMapArrayToExprs(reassociationMap));
     }
@@ -368,7 +368,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
     if (operandType == newInputOutputType)
       return operand;
     if (operandType.isa<MemRefType>()) {
-      return rewriter.create<linalg::CollapseShapeOp>(
+      return rewriter.create<memref::CollapseShapeOp>(
           loc, newInputOutputType, operand,
           convertAffineMapArrayToExprs(reassociationMap));
     }

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 518539376c9f..03fa871b1f7a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1300,6 +1300,189 @@ static LogicalResult verify(ReinterpretCastOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Reassociative reshape ops
+//===----------------------------------------------------------------------===//
+
+SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
+  return getSymbolLessAffineMaps(getReassociationExprs());
+}
+SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
+  OpBuilder b(this->getContext());
+  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+}
+
+SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
+  return getSymbolLessAffineMaps(getReassociationExprs());
+}
+SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
+  OpBuilder b(this->getContext());
+  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+}
+
+static void print(OpAsmPrinter &p, ExpandShapeOp op) {
+  ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
+}
+
+static void print(OpAsmPrinter &p, CollapseShapeOp op) {
+  ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
+}
+
+/// Detect whether memref dims [dim, dim + extent) can be reshaped without
+/// copies.
+static bool isReshapableDimBand(unsigned dim, unsigned extent,
+                                ArrayRef<int64_t> sizes,
+                                ArrayRef<AffineExpr> strides) {
+  assert(sizes.size() == strides.size() && "mismatched ranks");
+  // Walk adjacent pairs (idx, idx + 1) of the band; the `idx + 1 < e` bound
+  // keeps every `idx + 1` access below inside [dim, dim + extent).
+  for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
+    // Only bands of static shapes are reshapable. This is due to the fact that
+    // there is no relation between dynamic sizes and dynamic strides: we do not
+    // have enough information to know whether a "-1" size corresponds to the
+    // proper symbol in the AffineExpr of a stride.
+    if (ShapedType::isDynamic(sizes[idx + 1]))
+      return false;
+    // TODO: Refine this by passing the proper nDims and nSymbols so we can
+    // simplify on the fly and catch more reshapable cases.
+    if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
+      return false;
+  }
+  return true;
+}
+
+/// Compute the MemRefType obtained by applying the `reassociation` (which is
+/// expected to be valid) to `type`.
+/// If `type` is a contiguous MemRefType, this always produces a contiguous
+/// MemRefType.
+static MemRefType
+computeReshapeCollapsedType(MemRefType type,
+                            ArrayRef<AffineMap> reassociation) {
+  auto sizes = type.getShape();
+  AffineExpr offset;
+  SmallVector<AffineExpr, 4> strides;
+  auto status = getStridesAndOffset(type, strides, offset);
+  (void)status;
+  assert(succeeded(status) && "expected strided memref");
+
+  SmallVector<int64_t, 4> newSizes;
+  newSizes.reserve(reassociation.size());
+  SmallVector<AffineExpr, 4> newStrides;
+  newStrides.reserve(reassociation.size());
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    int64_t size = 1;
+    AffineExpr stride = strides[currentDim + dim - 1];
+    if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
+      size = ShapedType::kDynamicSize;
+      stride = AffineExpr();
+    } else {
+      for (unsigned d = 0; d < dim; ++d)
+        size *= sizes[currentDim + d];
+    }
+    newSizes.push_back(size);
+    newStrides.push_back(stride);
+    currentDim += dim;
+  }
+
+  // Early-exit: if `type` is contiguous, the result must be contiguous.
+  if (canonicalizeStridedLayout(type).getAffineMaps().empty())
+    return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
+
+  // Convert back to int64_t because we don't have enough information to create
+  // new strided layouts from AffineExpr only. This corresponds to a case where
+  // copies may be necessary.
+  int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
+  if (auto o = offset.dyn_cast<AffineConstantExpr>())
+    intOffset = o.getValue();
+  SmallVector<int64_t, 4> intStrides;
+  intStrides.reserve(strides.size());
+  for (auto stride : newStrides) {
+    if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
+      intStrides.push_back(cst.getValue());
+    else
+      intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
+  }
+  auto layout =
+      makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
+  return canonicalizeStridedLayout(
+      MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
+}
+
+void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
+                          ArrayRef<ReassociationIndices> reassociation,
+                          ArrayRef<NamedAttribute> attrs) {
+  auto memRefType = src.getType().cast<MemRefType>();
+  auto resultType = computeReshapeCollapsedType(
+      memRefType, getSymbolLessAffineMaps(
+                      convertReassociationIndicesToExprs(b, reassociation)));
+  build(b, result, resultType, src, attrs);
+  result.addAttribute(getReassociationAttrName(),
+                      getReassociationIndicesAttribute(b, reassociation));
+}
+
+void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
+                            ArrayRef<ReassociationIndices> reassociation,
+                            ArrayRef<NamedAttribute> attrs) {
+  auto memRefType = src.getType().cast<MemRefType>();
+  auto resultType = computeReshapeCollapsedType(
+      memRefType, getSymbolLessAffineMaps(
+                      convertReassociationIndicesToExprs(b, reassociation)));
+  build(b, result, resultType, src, attrs);
+  result.addAttribute(getReassociationAttrName(),
+                      getReassociationIndicesAttribute(b, reassociation));
+}
+
+template <typename ReshapeOp,
+          bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
+static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
+                                     MemRefType collapsedType) {
+  if (failed(
+          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
+    return failure();
+  auto maps = op.getReassociationMaps();
+  MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
+  if (collapsedType != expectedType)
+    return op.emitOpError("expected collapsed type to be ")
+           << expectedType << ", but got " << collapsedType;
+  return success();
+}
+
+static LogicalResult verify(ExpandShapeOp op) {
+  return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
+}
+
+void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<CollapseReshapeOps<ExpandShapeOp>,
+              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
+}
+
+static LogicalResult verify(CollapseShapeOp op) {
+  return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
+}
+
+void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<CollapseReshapeOps<CollapseShapeOp>,
+              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
+}
+OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+}
+OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 4cd72e2c9ff3..0c7aa52848b3 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -15,8 +15,6 @@
 
 using namespace mlir;
 
-constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
-
 Optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                         ShapedType targetType) {
@@ -183,6 +181,70 @@ Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
   return composedIndices;
 }
 
+SmallVector<SmallVector<AffineExpr, 2>, 2>
+mlir::convertReassociationIndicesToExprs(
+    OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices) {
+  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
+  for (const auto &indices : reassociationIndices) {
+    SmallVector<AffineExpr, 2> reassociationMap;
+    reassociationMap.reserve(indices.size());
+    for (int64_t index : indices)
+      reassociationMap.push_back(b.getAffineDimExpr(index));
+    reassociationMaps.push_back(std::move(reassociationMap));
+  }
+  return reassociationMaps;
+}
+
+template <typename AffineExprTy>
+unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
+  unsigned pos = 0;
+  for (const auto &exprs : exprArrays) {
+    for (auto expr : exprs) {
+      expr.walk([&pos](AffineExpr e) {
+        if (auto d = e.dyn_cast<AffineExprTy>())
+          pos = std::max(pos, d.getPosition());
+      });
+    }
+  }
+  return pos;
+}
+
+ArrayAttr mlir::getReassociationIndicesAttribute(
+    OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
+  SmallVector<Attribute, 4> reassociationAttr =
+      llvm::to_vector<4>(llvm::map_range(
+          reassociation, [&](ReassociationIndices indices) -> Attribute {
+            return b.getI64ArrayAttr(indices).cast<Attribute>();
+          }));
+  return b.getArrayAttr(reassociationAttr);
+}
+
+SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
+    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
+  SmallVector<ReassociationIndices, 2> reassociationIndices;
+  for (const auto &exprs : reassociationExprs) {
+    ReassociationIndices indices;
+    indices.reserve(exprs.size());
+    for (const auto &expr : exprs)
+      indices.push_back(expr.cast<AffineDimExpr>().getPosition());
+    reassociationIndices.push_back(indices);
+  }
+  return reassociationIndices;
+}
+
+SmallVector<AffineMap, 4>
+mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
+  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
+  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
+         "Expected symbol-less expressions");
+  SmallVector<AffineMap, 4> maps;
+  maps.reserve(reassociation.size());
+  for (const auto &exprs : reassociation) {
+    assert(!exprs.empty());
+    maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
+  }
+  return maps;
+}
 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                 int *invalidIndex) {
   if (reassociation.empty())

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 2c4d386a9b3f..aa372355f5f6 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -698,3 +698,105 @@ func @get_gv3_memref() {
   return
 }
 
+// -----
+
+func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+  // Reshapes that expand a contiguous memref with some 1's.
+  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+      : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+  return %0 : memref<1x3x4x1x5xf32>
+}
+// CHECK-LABEL: func @expand_shape_static
+//       CHECK:    llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.mlir.constant(1 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.mlir.constant(3 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.mlir.constant(4 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.mlir.constant(1 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.mlir.constant(5 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.mlir.constant(60 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.mlir.constant(20 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.mlir.constant(5 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.mlir.constant(5 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.mlir.constant(1 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+
+// -----
+
+func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
+    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  return %0 : memref<3x4x5xf32>
+}
+// CHECK-LABEL: func @collapse_shape_static
+//       CHECK:    llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.mlir.constant(3 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.mlir.constant(4 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.mlir.constant(5 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.mlir.constant(20 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.mlir.constant(5 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+//       CHECK:    llvm.mlir.constant(1 : index) : i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+
+// -----
+
+func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
+  %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
+  return %0 : memref<f32>
+}
+// CHECK-LABEL: func @collapse_shape_fold_zero_dim
+//       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+//       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+//       CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+//       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+
+// -----
+
+func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
+  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
+  return %0 : memref<1x1xf32>
+}
+// CHECK-LABEL: func @expand_shape_zero_dim
+//       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+//       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+//       CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+//       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+//       CHECK:   llvm.mlir.constant(1 : index) : i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+//       CHECK:   llvm.mlir.constant(1 : index) : i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+//       CHECK:   llvm.mlir.constant(1 : index) : i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+//       CHECK:   llvm.mlir.constant(1 : index) : i64
+//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>

diff  --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 98ab9ab06796..fa95deb1dbaa 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -261,7 +261,7 @@ func @bufferize_tensor_collapse_shape(%arg0: tensor<4x5xf32>) -> tensor<20xf32>
   return %out : tensor<20xf32>
 }
 // CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x5xf32>
-// CHECK: %[[RESHAPE:.*]] = linalg.collapse_shape %[[MEMREF]] {{\[}}[0, 1]]
+// CHECK: %[[RESHAPE:.*]] = memref.collapse_shape %[[MEMREF]] {{\[}}[0, 1]]
 // CHECK-SAME: : memref<4x5xf32> into memref<20xf32>
 // CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32>
 // CHECK: return %[[TENSOR]]

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 03e19909fa28..c453255d3948 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -81,19 +81,6 @@ func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
 
 // -----
 
-func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
-                                             -> memref<f32> {
-  %0 = linalg.collapse_shape %arg0 [[0, 1, 2]]
-      : memref<1x1x1xf32> into memref<1xf32>
-  %1 = linalg.collapse_shape %0 [] : memref<1xf32> into memref<f32>
-  return %1 : memref<f32>
-}
-// CHECK-LABEL: collapsing_memref_reshapes_to_zero
-//       CHECK:   linalg.collapse_shape %{{.*}} []
-//  CHECK-SAME:     memref<1x1x1xf32> into memref<f32>
-
-// -----
-
 func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x6x4x?x5xf32>
 {
   %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
@@ -108,34 +95,6 @@ func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x6x4x?x5xf32
 
 // -----
 
-func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>) -> memref<?x?xf32>
-{
-  %0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
-      : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
-  %1 = linalg.collapse_shape %0 [[0, 1], [2]]
-      : memref<?x?x?xf32> into memref<?x?xf32>
-  return %1 : memref<?x?xf32>
-}
-// CHECK-LABEL: collapsing_memref_reshapes
-//       CHECK:   linalg.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
-//   CHECK-NOT:   linalg.collapse_shape
-
-// -----
-
-func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x6x4x5x?xf32>
-{
-  %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
-      : memref<?x?xf32> into memref<?x4x?xf32>
-  %1 = linalg.expand_shape %0 [[0, 1], [2], [3, 4]]
-      : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
-  return %1 : memref<?x6x4x5x?xf32>
-}
-// CHECK-LABEL: expanding_memref_reshapes
-//       CHECK:   linalg.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
-//   CHECK-NOT:   linalg.expand_shape
-
-// -----
-
 func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
                                              -> tensor<1x1x1xf32> {
   %0 = linalg.tensor_expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
@@ -149,19 +108,6 @@ func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
 
 // -----
 
-func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
-                                             -> memref<1x1x1xf32> {
-  %0 = linalg.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
-  %1 = linalg.expand_shape %0 [[0, 1, 2]]
-      : memref<1xf32> into memref<1x1x1xf32>
-  return %1 : memref<1x1x1xf32>
-}
-// CHECK-LABEL: expanding_memref_reshapes_to_zero
-//       CHECK:   linalg.expand_shape %{{.*}} []
-//  CHECK-SAME:     memref<f32> into memref<1x1x1xf32>
-
-// -----
-
 func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
 {
   %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
@@ -188,32 +134,6 @@ func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
 
 // -----
 
-func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>
-{
-  %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
-      : memref<12x4xf32> into memref<3x4x4xf32>
-  %1 = linalg.collapse_shape %0 [[0, 1], [2]]
-      : memref<3x4x4xf32> into memref<12x4xf32>
-  return %1 : memref<12x4xf32>
-}
-// CHECK-LABEL: @fold_memref_reshape
-//   CHECK-NOT:   linalg.{{.*}}_shape
-
-// -----
-
-func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
-{
-  %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
-      : memref<?x?xf32> into memref<?x4x?xf32>
-  %1 = linalg.collapse_shape %0 [[0, 1], [2]]
-      : memref<?x4x?xf32> into memref<?x?xf32>
-  return %1 : memref<?x?xf32>
-}
-// CHECK-LABEL: @fold_memref_reshape_dynamic
-//   CHECK-NOT:   linalg.{{.*}}_shape
-
-// -----
-
 func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32>
 {
   %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 69353eb7b744..a4357b6e4cd1 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -479,7 +479,7 @@ func @drop_one_trip_loops(%arg0 : memref<?x1x?xf32>, %arg1 : f32, %shape: memref
 //   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
 //   CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @drop_one_trip_loops
-//       CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
+//       CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
 //       CHECK: linalg.generic
 //  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
@@ -556,7 +556,7 @@ func @drop_all_loops(%arg0 : memref<1x1xf32>) -> memref<1x1xf32>
 }
 //       CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
 // CHECK-LABEL: func @drop_all_loops
-//       CHECK:   linalg.collapse_shape %{{.*}} []
+//       CHECK:   memref.collapse_shape %{{.*}} []
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
 //  CHECK-SAME:     iterator_types = []
@@ -617,7 +617,7 @@ func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref<5xf3
 //   CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: func @leading_dim_1_canonicalization
-//       CHECK:   linalg.collapse_shape %{{.*}} {{\[}}[0, 1]]
+//       CHECK:   memref.collapse_shape %{{.*}} {{\[}}[0, 1]]
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[$MAP1]], #[[$MAP1]]]
 //  CHECK-SAME:     iterator_types = ["parallel"]
@@ -638,8 +638,8 @@ func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref<5xf3
 
 func @broadcast_test(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %shape : memref<5x5xf32>) -> memref<5x5xf32>
 {
-  %0 = linalg.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32>
-  %1 = linalg.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32>
+  %0 = memref.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32>
+  %1 = memref.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32>
   linalg.generic #trait
      ins(%0, %1 : memref<1x5xf32>, memref<5x1xf32>)
     outs(%shape : memref<5x5xf32>) {
@@ -686,7 +686,7 @@ func @broadcast_scalar(%arg0 : memref<1x1xf32>, %shape : memref<?x?xf32>) -> mem
 //   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @broadcast_scalar
 //  CHECK-SAME:   %[[ARG0:.*]]: memref<1x1xf32>
-//       CHECK:   %[[A:.*]] = linalg.collapse_shape %[[ARG0]] []
+//       CHECK:   %[[A:.*]] = memref.collapse_shape %[[ARG0]] []
 //  CHECK-SAME:     memref<1x1xf32> into memref<f32>
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
@@ -706,16 +706,16 @@ func @fold_unit_dim_memref_reshape_op(%arg0 : memref<5xf32>) -> memref<2x5xf32>
     ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
       linalg.yield %arg1 : f32
     }
-  %3 = linalg.collapse_shape %1 [[0, 1], [2]]
+  %3 = memref.collapse_shape %1 [[0, 1], [2]]
     : memref<1x2x5xf32> into memref<2x5xf32>
   return %3 : memref<2x5xf32>
 }
 // CHECK-LABEL: func @fold_unit_dim_memref_reshape_op
 //       CHECK:   %[[ALLOC:.*]] = memref.alloc() : memref<1x2x5xf32>
-//       CHECK:   %[[OUT:.*]] = linalg.collapse_shape %[[ALLOC]]
+//       CHECK:   %[[OUT:.*]] = memref.collapse_shape %[[ALLOC]]
 //       CHECK:   linalg.generic
 //       CHECK-SAME:   outs(%[[OUT:.*]] :
-//       CHECK:   %[[RESULT:.*]] = linalg.collapse_shape %[[ALLOC]]
+//       CHECK:   %[[RESULT:.*]] = memref.collapse_shape %[[ALLOC]]
 //       CHECK:   return %[[RESULT]]
 
 // -----
@@ -740,8 +740,8 @@ func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32>
 
 //       CHECK: func @fold_unit_dim_for_init_memref
 //       CHECK: %[[INIT:.+]] = memref.alloc() : memref<1xf32>
-//       CHECK: %[[INPUT_RESHAPE:.+]] = linalg.collapse_shape %{{.+}} {{\[}}[0, 1]] : memref<1x1000xf32> into memref<1000xf32>
-//       CHECK: %[[INIT_RESHAPE:.+]] = linalg.collapse_shape %[[INIT]] [] : memref<1xf32> into memref<f32>
+//       CHECK: %[[INPUT_RESHAPE:.+]] = memref.collapse_shape %{{.+}} {{\[}}[0, 1]] : memref<1x1000xf32> into memref<1000xf32>
+//       CHECK: %[[INIT_RESHAPE:.+]] = memref.collapse_shape %[[INIT]] [] : memref<1xf32> into memref<f32>
 //       CHECK: linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP2]]]
 //  CHECK-SAME:     iterator_types = ["reduction"]

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 1de474708803..569b9a1b387d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -308,58 +308,6 @@ func @generic(%arg0: memref<?x?xi4>) {
 
 // -----
 
-func @reshape(%arg0: memref<f32>) {
-  // expected-error @+1 {{expected non-zero memref ranks}}
-  %0 = linalg.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
-}
-
-// -----
-
-func @collapse_to_higher_rank(%arg0: memref<f32>) {
-  // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
-  %0 = linalg.collapse_shape %arg0 [[0]] : memref<f32> into memref<1xf32>
-}
-
-// -----
-
-func @expand_to_smaller_rank(%arg0: memref<1xf32>) {
-  // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
-  %0 = linalg.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
-}
-
-// -----
-
-func @reshape(%arg0: memref<?xf32>) {
-  // expected-error @+1 {{expected to collapse or expand dims}}
-  %0 = linalg.collapse_shape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
-}
-
-// -----
-
-func @reshape(%arg0: memref<?x?x?xf32>) {
-  // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
-  %0 = linalg.collapse_shape %arg0 [[0, 1]] :
-    memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
-}
-
-// -----
-
-func @reshape(%arg0: memref<?x?x?xf32>) {
-  // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}}
-  %0 = linalg.collapse_shape %arg0 [[0, 1], [1, 2]] :
-    memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
-}
-
-// -----
-
-func @reshape(%arg0: memref<?x?x?xf32>) {
-  // expected-error @+1 {{expected collapsed type to be 'memref<?x?xf32>', but got 'memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>'}}
-  %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
-    memref<?x?x?xf32> into memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
-}
-
-// -----
-
 func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
                             %arg1: memref<2x3xf32>,
                             %arg2: memref<?x?x?xf32>) {
@@ -397,7 +345,6 @@ func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
   return
 }
 
-
 // -----
 
 func @init_tensor_err(%arg0 : index, %arg1 : index)
@@ -438,16 +385,6 @@ func @illegal_expanding_reshape_dynamic_tensor
 
 // -----
 
-func @illegal_expanding_reshape_dynamic_memref
-  (%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32>
-{
-  // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
-  %0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]]
-      : memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
-  return %0 : memref<?x?x?x4x?xf32>
-}
-
-// -----
 
 func @illegal_expanding_reshape_static_tensor
   (%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32>
@@ -471,28 +408,6 @@ func @illegal_collapsing_reshape_static_tensor
 
 // -----
 
-func @illegal_expanding_reshape_static_memref
-  (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32>
-{
-  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
-  %0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]]
-      : memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
-  return %0 : memref<2x3x2x4x5xf32>
-}
-
-// -----
-
-func @illegal_collapsing_reshape_static_memref
-  (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32>
-{
-  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
-  %0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]]
-      : memref<2x3x2x4x5xf32> into memref<2x3x20xf32>
-  return %0 : memref<2x3x20xf32>
-}
-
-// -----
-
 func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
 {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
@@ -533,46 +448,6 @@ func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>) -> te
 
 // -----
 
-func @illegal_expanding_reshape_mixed_memref(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
-{
-  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
-  %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
-      : memref<?x?xf32> into memref<?x4x5xf32>
-  return %0 : memref<?x4x5xf32>
-}
-
-// -----
-
-func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
-{
-  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
-  %0 = linalg.expand_shape %arg0 [[0], [1, 2]]
-      : memref<?x?xf32> into memref<?x4x5xf32>
-  return %0 : memref<?x4x5xf32>
-}
-
-// -----
-
-func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
-{
-  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
-  %0 = linalg.collapse_shape %arg0 [[0, 1], [2]]
-      : memref<?x4x5xf32> into memref<?x?xf32>
-  return %0 : memref<?x?xf32>
-}
-
-// -----
-
-func @illegal_collapse_reshape_mixed_memref_2(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
-{
-  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
-  %0 = linalg.collapse_shape %arg0 [[0], [1, 2]]
-      : memref<?x4x5xf32> into memref<?x?xf32>
-  return %0 : memref<?x?xf32>
-}
-
-// -----
-
 func @pad_result_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32) -> tensor<?x?x?x8xf32> {
   // expected-error @+1 {{specified type 'tensor<?x?x?x8xf32>' does not match the inferred type 'tensor<?x?x?x9xi32>}}
   %0 = linalg.pad_tensor %arg0 low[1, %arg1, 2, 2] high[1, 2, %arg1, 3] {
@@ -824,6 +699,6 @@ func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
   linalg.generic #attrs ins(%A: memref<5xf32>) outs(%B: memref<5xf32>) {
 		^bb0(%a: f32, %b: f32):
 		linalg.yield %a : f32
-	} 
+	}
 	return
 }

diff  --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 4ddee115c0b0..6c60c5f3af3a 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -13,98 +13,3 @@ func @range(%arg0: index) {
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(i64, i64, i64)>
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)>
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)>
-
-func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
-  // Reshapes that expand a contiguous tensor with some 1's.
-  %0 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]]
-      : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-  return %0 : memref<1x3x4x1x5xf32>
-}
-// CHECK-LABEL: func @expand_shape_static
-//       CHECK:    llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(1 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(3 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(4 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(1 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(5 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(60 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(20 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(5 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(5 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(1 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-
-func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
-  %0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
-    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
-  return %0 : memref<3x4x5xf32>
-}
-// CHECK-LABEL: func @collapse_shape_static
-//       CHECK:    llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.mlir.constant(3 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.mlir.constant(4 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.mlir.constant(5 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.mlir.constant(20 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.mlir.constant(5 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.mlir.constant(1 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-
-func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
-  %0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
-  return %0 : memref<f32>
-}
-// CHECK-LABEL: func @collapse_shape_fold_zero_dim
-//       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-
-func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
-  %0 = linalg.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
-  return %0 : memref<1x1xf32>
-}
-// CHECK-LABEL: func @expand_shape_zero_dim
-//       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.mlir.constant(1 : index) : i64
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.mlir.constant(1 : index) : i64
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.mlir.constant(1 : index) : i64
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.mlir.constant(1 : index) : i64
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 5d842bba5960..e0d7ab2dfb24 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -12,9 +12,7 @@
 // CHECK-DAG: #[[$permute_1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
 // CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[$strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
 // CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
-// CHECK-DAG: #[[$strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)>
 // CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
 // CHECK-DAG: #[[$strided6D:.*]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)>
 
@@ -169,7 +167,6 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 
 // -----
 
-
 func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
   linalg.fill(%arg1, %arg0) : f32, memref<?xf32, offset: ?, strides: [1]>
   return
@@ -541,96 +538,6 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
 
 // -----
 
-func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>,
-                     %arg2: tensor<3x?x5xf32>) {
-  // Reshapes that collapse and expand back a contiguous buffer.
-  %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
-    memref<3x4x5xf32> into memref<12x5xf32>
-  %r0 = linalg.expand_shape %0 [[0, 1], [2]] :
-    memref<12x5xf32> into memref<3x4x5xf32>
-  %1 = linalg.collapse_shape %arg0 [[0], [1, 2]] :
-    memref<3x4x5xf32> into memref<3x20xf32>
-  %r1 = linalg.expand_shape %1 [[0], [1, 2]] :
-    memref<3x20xf32> into memref<3x4x5xf32>
-  %2 = linalg.collapse_shape %arg0 [[0, 1, 2]] :
-    memref<3x4x5xf32> into memref<60xf32>
-  %r2 = linalg.expand_shape %2 [[0, 1, 2]] :
-    memref<60xf32> into memref<3x4x5xf32>
-  // Reshapes that expand and collapse back a contiguous buffer with some 1's.
-  %3 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
-    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-  %r3 = linalg.collapse_shape %3 [[0, 1], [2], [3, 4]] :
-    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
-  // Reshapes on tensors.
-  %t0 = linalg.tensor_expand_shape %arg1 [[0, 1], [2], [3, 4]] :
-    tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
-  %rt0 = linalg.tensor_collapse_shape %t0 [[0, 1], [2], [3, 4]] :
-    tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
-  %t1 = linalg.tensor_expand_shape %arg2 [[0, 1], [2], [3, 4]] :
-    tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
-  %rt1 = linalg.tensor_collapse_shape %t1 [[0], [1, 2], [3, 4]] :
-    tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
-  return
-}
-// CHECK-LABEL: func @reshape_static
-//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<3x4x5xf32> into memref<12x5xf32>
-//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<12x5xf32> into memref<3x4x5xf32>
-//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
-//  CHECK-SAME:     memref<3x4x5xf32> into memref<3x20xf32>
-//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0], [1, 2]]
-//  CHECK-SAME:     memref<3x20xf32> into memref<3x4x5xf32>
-//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
-//  CHECK-SAME:     memref<3x4x5xf32> into memref<60xf32>
-//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1, 2]]
-//  CHECK-SAME:     memref<60xf32> into memref<3x4x5xf32>
-//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
-//  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
-//  CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
-//
-//       CHECK:   linalg.tensor_expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
-//       CHECK:   linalg.tensor_collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
-//       CHECK:   linalg.tensor_expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
-//       CHECK:   linalg.tensor_collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
-
-// -----
-
-func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
-                      %arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
-                      %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
-  %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
-    memref<?x?x?xf32> into memref<?x?xf32>
-  %r0 = linalg.expand_shape %0 [[0, 1], [2]] :
-    memref<?x?xf32> into memref<?x4x?xf32>
-  %1 = linalg.collapse_shape %arg1 [[0, 1], [2]] :
-    memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
-    memref<?x?xf32, offset : 0, strides : [?, 1]>
-  %r1 = linalg.expand_shape %1 [[0, 1], [2]] :
-    memref<?x?xf32, offset : 0, strides : [?, 1]> into
-    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
-  %2 = linalg.collapse_shape %arg2 [[0, 1], [2]] :
-    memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
-    memref<?x?xf32, offset : ?, strides : [?, 1]>
-  %r2 = linalg.expand_shape %2 [[0, 1], [2]] :
-    memref<?x?xf32, offset : ?, strides : [?, 1]> into
-    memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
-  return
-}
-// CHECK-LABEL: func @reshape
-//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
-//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?xf32> into memref<?x4x?xf32>
-//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
-//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
-//       CHECK:   linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
-//       CHECK:   linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
 
 func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
                 %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
@@ -670,17 +577,6 @@ func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (
 
 // -----
 
-func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
-{
-  %0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
-  %1 = linalg.expand_shape %0 [] : memref<f32> into memref<1x1xf32>
-  return %0, %1 : memref<f32>, memref<1x1xf32>
-}
-// CHECK-LABEL: func @memref_reshape_zero_dim
-//       CHECK:   linalg.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
-//       CHECK:   linalg.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
-
-// -----
 
 func @init_tensor(%arg0 : index, %arg1 : index)
 {
@@ -707,19 +603,6 @@ func @legal_collapsing_reshape_dynamic_tensor
 
 // -----
 
-func @legal_collapsing_reshape_dynamic_memref
-  (%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32>
-{
-  %0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]] :
-    memref<?x?x?x4x?xf32> into memref<?x?x?xf32>
-  return %0 : memref<?x?x?xf32>
-}
-//      CHECK: func @legal_collapsing_reshape_dynamic_memref
-//      CHECK:   linalg.collapse_shape
-// CHECK-SAME:    [0], [1], [2, 3, 4]
-
-// -----
-
 func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
   %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
   %1 = linalg.fill(%arg2, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 2ae2c06dea92..02a8ce4441c3 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -395,3 +395,81 @@ func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index)  {
   memref.store %0, %arg0[] : memref<memref<?xi32>>
   return 
 }
+
+// -----
+
+func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
+                                             -> memref<f32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
+      : memref<1x1x1xf32> into memref<1xf32>
+  %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
+  return %1 : memref<f32>
+}
+// CHECK-LABEL: collapsing_memref_reshapes_to_zero
+//       CHECK:   memref.collapse_shape %{{.*}} []
+//  CHECK-SAME:     memref<1x1x1xf32> into memref<f32>
+
+// -----
+
+func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
+    -> memref<?x?xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
+      : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
+  %1 = memref.collapse_shape %0 [[0, 1], [2]]
+      : memref<?x?x?xf32> into memref<?x?xf32>
+  return %1 : memref<?x?xf32>
+}
+// CHECK-LABEL: collapsing_memref_reshapes
+//       CHECK:   memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+//   CHECK-NOT:   memref.collapse_shape
+
+// -----
+
+func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
+    -> memref<?x6x4x5x?xf32> {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+      : memref<?x?xf32> into memref<?x4x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]]
+      : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
+  return %1 : memref<?x6x4x5x?xf32>
+}
+// CHECK-LABEL: expanding_memref_reshapes
+//       CHECK:   memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+//   CHECK-NOT:   memref.expand_shape
+
+// -----
+
+func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
+                                             -> memref<1x1x1xf32> {
+  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
+  %1 = memref.expand_shape %0 [[0, 1, 2]]
+      : memref<1xf32> into memref<1x1x1xf32>
+  return %1 : memref<1x1x1xf32>
+}
+// CHECK-LABEL: expanding_memref_reshapes_to_zero
+//       CHECK:   memref.expand_shape %{{.*}} []
+//  CHECK-SAME:     memref<f32> into memref<1x1x1xf32>
+
+// -----
+
+func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+      : memref<12x4xf32> into memref<3x4x4xf32>
+  %1 = memref.collapse_shape %0 [[0, 1], [2]]
+      : memref<3x4x4xf32> into memref<12x4xf32>
+  return %1 : memref<12x4xf32>
+}
+// CHECK-LABEL: @fold_memref_reshape
+//   CHECK-NOT:   linalg.{{.*}}_shape
+
+// -----
+
+func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+      : memref<?x?xf32> into memref<?x4x?xf32>
+  %1 = memref.collapse_shape %0 [[0, 1], [2]]
+      : memref<?x4x?xf32> into memref<?x?xf32>
+  return %1 : memref<?x?xf32>
+}
+// CHECK-LABEL: @fold_memref_reshape_dynamic
+//   CHECK-NOT:   linalg.{{.*}}_shape

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 63209ef10894..dcd1a6b12849 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -231,3 +231,125 @@ func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) {
   memref.copy %arg0, %arg1 : memref<2xf32> to memref<2xf16>
   return
 }
+
+// -----
+
+func @expand_shape(%arg0: memref<f32>) {
+  // expected-error @+1 {{expected non-zero memref ranks}}
+  %0 = memref.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
+}
+
+// -----
+
+func @collapse_shape_to_higher_rank(%arg0: memref<f32>) {
+  // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
+  %0 = memref.collapse_shape %arg0 [[0]] : memref<f32> into memref<1xf32>
+}
+
+// -----
+
+func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
+  // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
+  %0 = memref.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
+}
+
+// -----
+
+func @collapse_shape(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{expected to collapse or expand dims}}
+  %0 = memref.collapse_shape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
+}
+
+// -----
+
+func @collapse_shape_mismatch_indices_num(%arg0: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
+  %0 = memref.collapse_shape %arg0 [[0, 1]] :
+    memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
+}
+
+// -----
+
+func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}}
+  %0 = memref.collapse_shape %arg0 [[0, 1], [1, 2]] :
+    memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
+}
+
+// -----
+
+func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected collapsed type to be 'memref<?x?xf32>', but got 'memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>'}}
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
+    memref<?x?x?xf32> into memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
+}
+
+// -----
+
+func @expand_shape_illegal_dynamic_memref
+  (%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32> {
+  // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
+  %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
+      : memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
+  return %0 : memref<?x?x?x4x?xf32>
+}
+
+// -----
+
+func @expand_shape_illegal_static_memref
+  (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
+      : memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
+  return %0 : memref<2x3x2x4x5xf32>
+}
+
+// -----
+
+func @collapse_shape_illegal_static_memref
+  (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32> {
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = memref.collapse_shape %arg0 [[0], [1], [2, 3, 4]]
+      : memref<2x3x2x4x5xf32> into memref<2x3x20xf32>
+  return %0 : memref<2x3x20xf32>
+}
+
+// -----
+
+func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>)
+    -> memref<?x4x5xf32> {
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+      : memref<?x?xf32> into memref<?x4x5xf32>
+  return %0 : memref<?x4x5xf32>
+}
+
+// -----
+
+func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>)
+    -> memref<?x4x5xf32> {
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = memref.expand_shape %arg0 [[0], [1, 2]]
+      : memref<?x?xf32> into memref<?x4x5xf32>
+  return %0 : memref<?x4x5xf32>
+}
+
+// -----
+
+func @collapse_shape_illegal_mixed_memref(%arg0 : memref<?x4x5xf32>)
+    -> memref<?x?xf32> {
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
+      : memref<?x4x5xf32> into memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
+
+// -----
+
+func @collapse_shape_illegal_mixed_memref_2(%arg0 : memref<?x4x5xf32>)
+    -> memref<?x?xf32> {
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = memref.collapse_shape %arg0 [[0], [1, 2]]
+      : memref<?x4x5xf32> into memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 993a6131ab51..714b769099a1 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -1,6 +1,11 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
 
+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
+// CHECK-DAG: #[[$strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
+// CHECK-DAG: #[[$strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)>
+
 // CHECK-LABEL: test_buffer_cast
 func @test_buffer_cast(%arg0: tensor<?xi64>, %arg1: tensor<*xi64>) -> (memref<?xi64, affine_map<(d0) -> (d0 + 7)>>, memref<*xi64, 1>) {
   %0 = memref.buffer_cast %arg0 : memref<?xi64, affine_map<(d0) -> (d0 + 7)>>
@@ -95,3 +100,114 @@ func @memref_alloca_scope() {
   }
   return
 }
+
+func @expand_collapse_shape_static(%arg0: memref<3x4x5xf32>,
+                                   %arg1: tensor<3x4x5xf32>,
+                                   %arg2: tensor<3x?x5xf32>) {
+  // Collapse a contiguous 3x4x5 buffer with each grouping, then expand it back.
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
+    memref<3x4x5xf32> into memref<12x5xf32>
+  %r0 = memref.expand_shape %0 [[0, 1], [2]] :
+    memref<12x5xf32> into memref<3x4x5xf32>
+  %1 = memref.collapse_shape %arg0 [[0], [1, 2]] :
+    memref<3x4x5xf32> into memref<3x20xf32>
+  %r1 = memref.expand_shape %1 [[0], [1, 2]] :
+    memref<3x20xf32> into memref<3x4x5xf32>
+  %2 = memref.collapse_shape %arg0 [[0, 1, 2]] :
+    memref<3x4x5xf32> into memref<60xf32>
+  %r2 = memref.expand_shape %2 [[0, 1, 2]] :
+    memref<60xf32> into memref<3x4x5xf32>
+  // Expand by inserting unit dims (3x4x5 -> 1x3x4x1x5), then collapse back.
+  %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
+    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+  %r3 = memref.collapse_shape %3 [[0, 1], [2], [3, 4]] :
+    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  // Reshapes on tensors use the linalg.tensor_* ops; %arg2 has a dynamic dim.
+  %t0 = linalg.tensor_expand_shape %arg1 [[0, 1], [2], [3, 4]] :
+    tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+  %rt0 = linalg.tensor_collapse_shape %t0 [[0, 1], [2], [3, 4]] :
+    tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+  %t1 = linalg.tensor_expand_shape %arg2 [[0, 1], [2], [3, 4]] :
+    tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+  %rt1 = linalg.tensor_collapse_shape %t1 [[0], [1, 2], [3, 4]] :
+    tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
+  return
+}
+// CHECK-LABEL: func @expand_collapse_shape_static
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<12x5xf32>
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<12x5xf32> into memref<3x4x5xf32>
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<3x20xf32>
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+//  CHECK-SAME:     memref<3x20xf32> into memref<3x4x5xf32>
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<60xf32>
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1, 2]]
+//  CHECK-SAME:     memref<60xf32> into memref<3x4x5xf32>
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+//  CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+//
+//       CHECK:   linalg.tensor_expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+//       CHECK:   linalg.tensor_collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+//       CHECK:   linalg.tensor_expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+//       CHECK:   linalg.tensor_collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
+
+
+func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
+         %arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
+         %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :  // case 1: no explicit layout
+    memref<?x?x?xf32> into memref<?x?xf32>
+  %r0 = memref.expand_shape %0 [[0, 1], [2]] :  // expanded dim 1 is static 4, the others stay dynamic
+    memref<?x?xf32> into memref<?x4x?xf32>
+  %1 = memref.collapse_shape %arg1 [[0, 1], [2]] :  // case 2: strided layout, static offset 0
+    memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
+    memref<?x?xf32, offset : 0, strides : [?, 1]>
+  %r1 = memref.expand_shape %1 [[0, 1], [2]] :
+    memref<?x?xf32, offset : 0, strides : [?, 1]> into
+    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
+  %2 = memref.collapse_shape %arg2 [[0, 1], [2]] :  // case 3: strided layout, dynamic offset
+    memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
+    memref<?x?xf32, offset : ?, strides : [?, 1]>
+  %r2 = memref.expand_shape %2 [[0, 1], [2]] :
+    memref<?x?xf32, offset : ?, strides : [?, 1]> into
+    memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
+  return
+}
+// CHECK-LABEL: func @expand_collapse_shape_dynamic
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?xf32> into memref<?x4x?xf32>
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
+
+func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>)
+    -> (memref<f32>, memref<1x1xf32>) {
+  %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>  // empty reassociation: 1x1 becomes rank-0
+  %1 = memref.expand_shape %0 [] : memref<f32> into memref<1x1xf32>  // and back from rank-0 to 1x1
+  return %0, %1 : memref<f32>, memref<1x1xf32>
+}
+// CHECK-LABEL: func @expand_collapse_shape_zero_dim
+//       CHECK:   memref.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
+//       CHECK:   memref.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+
+func @collapse_shape_to_dynamic
+  (%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32> {
+  %0 = memref.collapse_shape %arg0 [[0], [1], [2, 3, 4]] :  // group [2, 3, 4] mixes static 4 with '?' dims; result dim 2 is '?'
+    memref<?x?x?x4x?xf32> into memref<?x?x?xf32>
+  return %0 : memref<?x?x?xf32>
+}
+//      CHECK: func @collapse_shape_to_dynamic
+//      CHECK:   memref.collapse_shape
+// CHECK-SAME:    [0], [1], [2, 3, 4]

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index c9f831f07a7d..1fd0a2bd2100 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5862,6 +5862,7 @@ cc_library(
         ":LLVMDialect",
         ":LinalgOps",
         ":LinalgTransforms",
+        ":MemRefToLLVM",
         ":Pass",
         ":SCFDialect",
         ":SCFToStandard",


        


More information about the Mlir-commits mailing list