[Mlir-commits] [mlir] 3914836 - [mlir][MemRefToLLVM] Remove the code for lowering collapse/expand_shape

Quentin Colombet llvmlistbot at llvm.org
Mon Jan 23 07:48:34 PST 2023


Author: Quentin Colombet
Date: 2023-01-23T15:41:55Z
New Revision: 3914836273aa51f14121abd77cc6986cd3ccee11

URL: https://github.com/llvm/llvm-project/commit/3914836273aa51f14121abd77cc6986cd3ccee11
DIFF: https://github.com/llvm/llvm-project/commit/3914836273aa51f14121abd77cc6986cd3ccee11.diff

LOG: [mlir][MemRefToLLVM] Remove the code for lowering collapse/expand_shape

collapse/expand_shape are supposed to be expanded before we hit the
lowering code.
The expansion is done with the pass called expand-strided-metadata.

This patch is NFC in spirit but not in practice because
expand-strided-metadata won't try to accommodate "invalid" strides
for dynamic sizes that are 1 at runtime.

The previous code was broken in that respect too, but differently: it
handled only the case of row-major layouts.
That whole part is being reworked separately.

Differential Revision: https://reviews.llvm.org/D136483

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4d5af81757d03..7132560ad8b1d 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1374,256 +1374,8 @@ struct MemRefReshapeOpLowering
   }
 };
 
-/// Helper function to convert a vector of `OpFoldResult`s into a vector of
-/// `Value`s.
-static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
-                                      Type &llvmIndexType,
-                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
-  return llvm::to_vector<4>(
-      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
-        if (auto attr = value.dyn_cast<Attribute>())
-          return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
-        return value.get<Value>();
-      }));
-}
-
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
-  for (auto &en : enumerate(reassociation)) {
-    for (auto dim : en.value())
-      expandedDimToCollapsedDim[dim] = en.index();
-  }
-  return expandedDimToCollapsedDim;
-}
-
-static OpFoldResult
-getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
-                         int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
-                         MemRefDescriptor &inDesc,
-                         ArrayRef<int64_t> inStaticShape,
-                         ArrayRef<ReassociationIndices> reassocation,
-                         DenseMap<int64_t, int64_t> &outDimToInDimMap) {
-  int64_t outDimSize = outStaticShape[outDimIndex];
-  if (!ShapedType::isDynamic(outDimSize))
-    return b.getIndexAttr(outDimSize);
-
-  // Calculate the multiplication of all the out dim sizes except the
-  // current dim.
-  int64_t inDimIndex = outDimToInDimMap[outDimIndex];
-  int64_t otherDimSizesMul = 1;
-  for (auto otherDimIndex : reassocation[inDimIndex]) {
-    if (otherDimIndex == static_cast<unsigned>(outDimIndex))
-      continue;
-    int64_t otherDimSize = outStaticShape[otherDimIndex];
-    assert(!ShapedType::isDynamic(otherDimSize) &&
-           "single dimension cannot be expanded into multiple dynamic "
-           "dimensions");
-    otherDimSizesMul *= otherDimSize;
-  }
-
-  // outDimSize = inDimSize / otherOutDimSizesMul
-  int64_t inDimSize = inStaticShape[inDimIndex];
-  Value inDimSizeDynamic =
-      ShapedType::isDynamic(inDimSize)
-          ? inDesc.size(b, loc, inDimIndex)
-          : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
-                                       b.getIndexAttr(inDimSize));
-  Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
-      loc, inDimSizeDynamic,
-      b.create<LLVM::ConstantOp>(loc, llvmIndexType,
-                                 b.getIndexAttr(otherDimSizesMul)));
-  return outDimSizeDynamic;
-}
-
-static OpFoldResult getCollapsedOutputDimSize(
-    OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
-    int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
-    MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
-  if (!ShapedType::isDynamic(outDimSize))
-    return b.getIndexAttr(outDimSize);
-
-  Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
-  Value outDimSizeDynamic = c1;
-  for (auto inDimIndex : reassocation[outDimIndex]) {
-    int64_t inDimSize = inStaticShape[inDimIndex];
-    Value inDimSizeDynamic =
-        ShapedType::isDynamic(inDimSize)
-            ? inDesc.size(b, loc, inDimIndex)
-            : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
-                                         b.getIndexAttr(inDimSize));
-    outDimSizeDynamic =
-        b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
-  }
-  return outDimSizeDynamic;
-}
-
-static SmallVector<OpFoldResult, 4>
-getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
-                        ArrayRef<ReassociationIndices> reassociation,
-                        ArrayRef<int64_t> inStaticShape,
-                        MemRefDescriptor &inDesc,
-                        ArrayRef<int64_t> outStaticShape) {
-  return llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
-        return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
-                                         outStaticShape[outDimIndex],
-                                         inStaticShape, inDesc, reassociation);
-      }));
-}
-
-static SmallVector<OpFoldResult, 4>
-getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
-                       ArrayRef<ReassociationIndices> reassociation,
-                       ArrayRef<int64_t> inStaticShape,
-                       MemRefDescriptor &inDesc,
-                       ArrayRef<int64_t> outStaticShape) {
-  DenseMap<int64_t, int64_t> outDimToInDimMap =
-      getExpandedDimToCollapsedDimMap(reassociation);
-  return llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
-        return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
-                                        outStaticShape, inDesc, inStaticShape,
-                                        reassociation, outDimToInDimMap);
-      }));
-}
-
-static SmallVector<Value>
-getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
-                      ArrayRef<ReassociationIndices> reassociation,
-                      ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
-                      ArrayRef<int64_t> outStaticShape) {
-  return outStaticShape.size() < inStaticShape.size()
-             ? getAsValues(b, loc, llvmIndexType,
-                           getCollapsedOutputShape(b, loc, llvmIndexType,
-                                                   reassociation, inStaticShape,
-                                                   inDesc, outStaticShape))
-             : getAsValues(b, loc, llvmIndexType,
-                           getExpandedOutputShape(b, loc, llvmIndexType,
-                                                  reassociation, inStaticShape,
-                                                  inDesc, outStaticShape));
-}
-
-static void fillInStridesForExpandedMemDescriptor(
-    OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
-    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
-  // See comments for computeExpandedLayoutMap for details on how the strides
-  // are calculated.
-  for (auto &en : llvm::enumerate(reassociation)) {
-    auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
-    for (auto dstIndex : llvm::reverse(en.value())) {
-      dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
-      Value size = dstDesc.size(b, loc, dstIndex);
-      currentStrideToExpand =
-          b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
-    }
-  }
-}
-
-static void fillInStridesForCollapsedMemDescriptor(
-    ConversionPatternRewriter &rewriter, Location loc, Operation *op,
-    TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
-    MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
-  auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
-  // See comments for computeCollapsedLayoutMap for details on how the strides
-  // are calculated.
-  auto srcShape = srcType.getShape();
-  for (auto &en : llvm::enumerate(reassociation)) {
-    rewriter.setInsertionPoint(op);
-    auto dstIndex = en.index();
-    ArrayRef<int64_t> ref = llvm::ArrayRef(en.value());
-    while (srcShape[ref.back()] == 1 && ref.size() > 1)
-      ref = ref.drop_back();
-    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
-      dstDesc.setStride(rewriter, loc, dstIndex,
-                        srcDesc.stride(rewriter, loc, ref.back()));
-    } else {
-      // Iterate over the source strides in reverse order. Skip over the
-      // dimensions whose size is 1.
-      // TODO: we should take the minimum stride in the reassociation group
-      // instead of just the first where the dimension is not 1.
-      //
-      // +------------------------------------------------------+
-      // | curEntry:                                            |
-      // |   %srcStride = strides[srcIndex]                     |
-      // |   %neOne = cmp sizes[srcIndex],1                     +--+
-      // |   cf.cond_br %neOne, continue(%srcStride), nextEntry |  |
-      // +-------------------------+----------------------------+  |
-      //                           |                               |
-      //                           v                               |
-      //            +-----------------------------+                |
-      //            | nextEntry:                  |                |
-      //            |   ...                       +---+            |
-      //            +--------------+--------------+   |            |
-      //                           |                  |            |
-      //                           v                  |            |
-      //            +-----------------------------+   |            |
-      //            | nextEntry:                  |   |            |
-      //            |   ...                       |   |            |
-      //            +--------------+--------------+   |   +--------+
-      //                           |                  |   |
-      //                           v                  v   v
-      //   +--------------------------------------------------+
-      //   | continue(%newStride):                            |
-      //   |   %newMemRefDes = setStride(%newStride,dstIndex) |
-      //   +--------------------------------------------------+
-      OpBuilder::InsertionGuard guard(rewriter);
-      Block *initBlock = rewriter.getInsertionBlock();
-      Block *continueBlock =
-          rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
-      continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
-      rewriter.setInsertionPointToStart(continueBlock);
-      dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
-
-      Block *curEntryBlock = initBlock;
-      Block *nextEntryBlock;
-      for (auto srcIndex : llvm::reverse(ref)) {
-        if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
-          continue;
-        rewriter.setInsertionPointToEnd(curEntryBlock);
-        Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
-        if (srcIndex == ref.front()) {
-          rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
-          break;
-        }
-        Value one = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
-                                                      rewriter.getIndexAttr(1));
-        Value predNeOne = rewriter.create<LLVM::ICmpOp>(
-            loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
-            one);
-        {
-          OpBuilder::InsertionGuard guard(rewriter);
-          nextEntryBlock = rewriter.createBlock(
-              initBlock->getParent(), Region::iterator(continueBlock), {});
-        }
-        rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
-                                        srcStride, nextEntryBlock,
-                                        std::nullopt);
-        curEntryBlock = nextEntryBlock;
-      }
-    }
-  }
-}
-
-static void fillInDynamicStridesForMemDescriptor(
-    ConversionPatternRewriter &b, Location loc, Operation *op,
-    TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
-    MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
-    ArrayRef<ReassociationIndices> reassociation) {
-  if (srcType.getRank() > dstType.getRank())
-    fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
-                                           srcDesc, dstDesc, reassociation);
-  else
-    fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
-                                          reassociation);
-}
-
-// ReshapeOp creates a new view descriptor of the proper rank.
-// For now, the only conversion supported is for target MemRef with static sizes
-// and strides.
+/// ReassociatingReshapeOp must be expanded before we reach this stage.
+/// Report that information.
 template <typename ReshapeOp>
 class ReassociatingReshapeOpConversion
     : public ConvertOpToLLVMPattern<ReshapeOp> {
@@ -1634,56 +1386,9 @@ class ReassociatingReshapeOpConversion
   LogicalResult
   matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefType dstType = reshapeOp.getResultType();
-    MemRefType srcType = reshapeOp.getSrcType();
-
-    int64_t offset;
-    SmallVector<int64_t, 4> strides;
-    if (failed(getStridesAndOffset(dstType, strides, offset))) {
-      return rewriter.notifyMatchFailure(
-          reshapeOp, "failed to get stride and offset exprs");
-    }
-
-    MemRefDescriptor srcDesc(adaptor.getSrc());
-    Location loc = reshapeOp->getLoc();
-    auto dstDesc = MemRefDescriptor::undef(
-        rewriter, loc, this->typeConverter->convertType(dstType));
-    dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
-    dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
-    dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
-
-    ArrayRef<int64_t> srcStaticShape = srcType.getShape();
-    ArrayRef<int64_t> dstStaticShape = dstType.getShape();
-    Type llvmIndexType =
-        this->typeConverter->convertType(rewriter.getIndexType());
-    SmallVector<Value> dstShape = getDynamicOutputShape(
-        rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
-        srcStaticShape, srcDesc, dstStaticShape);
-    for (auto &en : llvm::enumerate(dstShape))
-      dstDesc.setSize(rewriter, loc, en.index(), en.value());
-
-    if (llvm::all_of(strides, isStaticStrideOrOffset)) {
-      for (auto &en : llvm::enumerate(strides))
-        dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
-    } else if (srcType.getLayout().isIdentity() &&
-               dstType.getLayout().isIdentity()) {
-      Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
-                                                   rewriter.getIndexAttr(1));
-      Value stride = c1;
-      for (auto dimIndex :
-           llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
-        dstDesc.setStride(rewriter, loc, dimIndex, stride);
-        stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
-      }
-    } else {
-      // There could be mixed static/dynamic strides. For simplicity, we
-      // recompute all strides if there is at least one dynamic stride.
-      fillInDynamicStridesForMemDescriptor(
-          rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
-          srcDesc, dstDesc, reshapeOp.getReassociationIndices());
-    }
-    rewriter.replaceOp(reshapeOp, {dstDesc});
-    return success();
+    return rewriter.notifyMatchFailure(
+        reshapeOp,
+        "reassociation operations should have been expanded beforehand");
   }
 };
 

diff  --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 061cc6d17a29f..51650adc9a9ff 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -373,3 +373,298 @@ func.func @subview_negative_stride(%arg0 : memref<7xf32>) -> memref<7xf32, strid
   %0 = memref.subview %arg0[6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
   return %0 : memref<7xf32, strided<[-1], offset: 6>>
 }
+
+// -----
+
+func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
+    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  return %0 : memref<3x4x5xf32>
+}
+// CHECK-LABEL: func @collapse_shape_static
+// CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x3x4x1x5xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C3]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C20:.*]] = llvm.mlir.constant(20 : index) : i64
+// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[C20]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
+// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[C4]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C5:.*]] = llvm.mlir.constant(5 : index) : i64
+// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C5]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[C5]], %[[DESC6]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[C1]], %[[DESC7]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC8]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<3x4x5xf32>
+// CHECK:           return %[[RES]] : memref<3x4x5xf32>
+// CHECK:         }
+
+// -----
+
+func.func @collapse_shape_dynamic_with_non_identity_layout(
+        %arg0 : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) ->
+        memref<4x?xf32, strided<[?, ?], offset: ?>> {
+  %0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
+    memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into
+    memref<4x?xf32, strided<[?, ?], offset: ?>>
+  return %0 : memref<4x?xf32, strided<[?, ?], offset: ?>>
+}
+// CHECK-LABEL:   func.func @collapse_shape_dynamic_with_non_identity_layout(
+// CHECK-SAME:                                                               %[[ARG:.*]]: memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) -> memref<4x?xf32, strided<[?, ?], offset: ?>> {
+// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index
+// CHECK:           %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64
+// CHECK:           %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]]  : i64
+// CHECK:           %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
+// CHECK:           %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
+// CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
+// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C4]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> to memref<4x?xf32, strided<[?, ?], offset: ?>>
+// CHECK:           return %[[RES]] : memref<4x?xf32, strided<[?, ?], offset: ?>>
+// CHECK:         }
+// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
+//       CHECK32:      llvm.mlir.constant(1 : index) : i32
+//       CHECK32:      llvm.mlir.constant(4 : index) : i32
+//       CHECK32:      llvm.mlir.constant(1 : index) : i32
+
+// -----
+
+
+func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+  // Reshapes that expand a contiguous tensor with some 1's.
+  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+      : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+  return %0 : memref<1x3x4x1x5xf32>
+}
+// CHECK-LABEL: func @expand_shape_static
+// CHECK-SAME: %[[ARG:.*]]: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<3x4x5xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[C60:.*]] = llvm.mlir.constant(60 : index) : i64
+// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[C60]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[C3]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[C20:.*]] = llvm.mlir.constant(20 : index) : i64
+// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C20]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
+// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[C4]], %[[DESC6]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[C5:.*]] = llvm.mlir.constant(5 : index) : i64
+// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[C5]], %[[DESC7]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[DESC9:.*]] = llvm.insertvalue %[[C1]], %[[DESC8]][3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[DESC10:.*]] = llvm.insertvalue %[[C5]], %[[DESC9]][4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[DESC11:.*]] = llvm.insertvalue %[[C5]], %[[DESC10]][3, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[DESC12:.*]] = llvm.insertvalue %[[C1]], %[[DESC11]][4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC12]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)> to memref<1x3x4x1x5xf32>
+// CHECK:           return %[[RES]] : memref<1x3x4x1x5xf32>
+// CHECK:         }
+
+// -----
+
+func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
+  %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
+  return %0 : memref<f32>
+}
+// CHECK-LABEL:   func.func @collapse_shape_fold_zero_dim(
+// CHECK-SAME:                                            %[[ARG:.*]]: memref<1x1xf32>) -> memref<f32> {
+// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x1xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC2]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)> to memref<f32>
+// CHECK:           return %[[RES]] : memref<f32>
+// CHECK:         }
+
+// -----
+
+func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
+  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
+  return %0 : memref<1x1xf32>
+}
+
+// CHECK-LABEL:   func.func @expand_shape_zero_dim(
+// CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<1x1xf32> {
+// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[C1]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[C1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> to memref<1x1xf32>
+// CHECK:           return %[[RES]] : memref<1x1xf32>
+// CHECK:         }
+
+// -----
+
+func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32> {
+  %0 = memref.collapse_shape %arg0 [[0], [1, 2]]:  memref<1x2x?xf32> into memref<1x?xf32>
+  return %0 : memref<1x?xf32>
+}
+
+// CHECK-LABEL:   func.func @collapse_shape_dynamic(
+// CHECK-SAME:                                      %[[ARG:.*]]: memref<1x2x?xf32>) -> memref<1x?xf32> {
+// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x2x?xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK:           %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]]  : i64
+// CHECK:           %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
+// CHECK:           %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
+// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[IS_MIN_STRIDE1:.*]] = llvm.icmp "slt" %[[STRIDE1]], %[[C1]] : i64
+// CHECK:           %[[MIN_STRIDE1:.*]] = llvm.select %[[IS_MIN_STRIDE1]], %[[STRIDE1]], %[[C1]] : i1, i64
+// CHECK:           %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
+// CHECK:           %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
+// CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[MIN_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> to memref<1x?xf32>
+// CHECK:           return %[[RES]] : memref<1x?xf32>
+// CHECK:         }
+
+// -----
+
+func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
+  %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32>
+  return %0 : memref<1x2x?xf32>
+}
+
+// CHECK-LABEL:   func.func @expand_shape_dynamic(
+// CHECK-SAME:                                    %[[ARG:.*]]: memref<1x?xf32>) -> memref<1x2x?xf32> {
+// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK:           %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
+// CHECK:           %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
+// CHECK:           %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]]  : i64
+// CHECK:           %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
+// CHECK:           %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]]  : i64
+// CHECK:           %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]]  : i64
+// CHECK:           %[[FINAL_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
+// CHECK:           %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE2]] : i64 to index
+// CHECK:           %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
+// CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[C2]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// In this example stride1 and size2 are the same.
+// Hence with CSE, we get the same SSA value.
+// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC6]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[C1]], %[[DESC7]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC8]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32>
+// CHECK:           return %[[RES]] : memref<1x2x?xf32>
+// CHECK:         }
+
+// -----
+
+func.func @expand_shape_dynamic_with_non_identity_layout(
+            %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) ->
+            memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
+  %0 = memref.expand_shape %arg0 [[0], [1, 2]]:
+    memref<1x?xf32, strided<[?, ?], offset: ?>> into
+    memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+  return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+}
+// CHECK-LABEL:   func.func @expand_shape_dynamic_with_non_identity_layout(
+// CHECK-SAME:                                                             %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
+// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK:           %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
+// CHECK:           %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
+// CHECK:           %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]]  : i64
+// CHECK:           %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
+// CHECK:           %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]]  : i64
+// CHECK:           %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]]  : i64
+// CHECK:           %[[TMP_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
+// CHECK:           %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[TMP_SIZE2]] : i64 to index
+// CHECK:           %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
+// CHECK:           %[[FINAL_STRIDE1:.*]] = llvm.mul %[[TMP_SIZE2]], %[[STRIDE1]]
+// CHECK:           %[[STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_STRIDE1]] : i64 to index
+// CHECK:           %[[FINAL_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[STRIDE1_TO_IDX]] : index to i64
+// CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC1]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC2]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[C1]], %[[DESC3]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC4]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C2]], %[[DESC5]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_STRIDE1]], %[[DESC6]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC7]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC9:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESC8]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC9]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK:           return %[[RES]] : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK:         }
+
+// -----
+
+// CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout
+func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> {
+// CHECK-NOT: memref.collapse_shape
+  %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>>
+  return %1 : memref<64xf32, strided<[1], offset: ?>>
+}

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index c12a07881cd46..1a8a75d1e2b35 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -290,246 +290,27 @@ memref.global "private" @gv4 : memref<f32> = dense<1.0> {alignment = 64}
 
 // -----
 
-func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
-  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
-    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
-  return %0 : memref<3x4x5xf32>
-}
-// CHECK-LABEL: func @collapse_shape_static
-//       CHECK:    llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.mlir.constant(3 : index) : i64
-//       CHECK:    llvm.mlir.constant(4 : index) : i64
-//       CHECK:    llvm.mlir.constant(5 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.mlir.constant(20 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.mlir.constant(5 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.mlir.constant(1 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-
-// -----
-
-func.func @collapse_shape_dynamic_with_non_identity_layout(
-        %arg0 : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) ->
-        memref<4x?xf32, strided<[?, ?], offset: ?>> {
-  %0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
-    memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into
-    memref<4x?xf32, strided<[?, ?], offset: ?>>
-  return %0 : memref<4x?xf32, strided<[?, ?], offset: ?>>
-}
-// CHECK-LABEL:   func @collapse_shape_dynamic_with_non_identity_layout(
-//       CHECK:      llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:      llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:      llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:      llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:      llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:      llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:      llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:      llvm.mlir.constant(1 : index) : i64
-//       CHECK:      llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:      llvm.mul %{{.*}}, %{{.*}}  : i64
-//       CHECK:      llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:      llvm.mul %{{.*}}, %{{.*}}  : i64
-//       CHECK:      llvm.mlir.constant(4 : index) : i64
-//       CHECK:      llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:      llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:      llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:      llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:      llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:      llvm.mlir.constant(1 : index) : i64
-//       CHECK:      llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:      llvm.icmp "ne" %{{.*}}, %{{.*}} : i64
-//       CHECK:      llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i64), ^bb1
-//       CHECK:      ^bb1:
-//       CHECK:      llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:      llvm.br ^bb2(%{{.*}} : i64)
-//       CHECK:      ^bb2(%[[STRIDE:.*]]: i64):
-//       CHECK:      llvm.insertvalue %[[STRIDE]], %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
-//       CHECK32:      llvm.mlir.constant(1 : index) : i32
-//       CHECK32:      llvm.mlir.constant(4 : index) : i32
-//       CHECK32:      llvm.mlir.constant(1 : index) : i32
-
-// -----
-
+// memref.expand_shape is lowered by -expand-strided-metadata, not by memref-to-llvm, so it must be left as-is here.
+// CHECK-LABEL: func @expand_shape_static(
+// CHECK-SAME:         %[[ARG:.*]]: memref<{{.*}}>)
 func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+  // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]]
   // Reshapes that expand a contiguous tensor with some 1's.
   %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
       : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
   return %0 : memref<1x3x4x1x5xf32>
 }
-// CHECK-LABEL: func @expand_shape_static
-//       CHECK:    llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(1 : index) : i64
-//       CHECK:    llvm.mlir.constant(3 : index) : i64
-//       CHECK:    llvm.mlir.constant(4 : index) : i64
-//       CHECK:    llvm.mlir.constant(1 : index) : i64
-//       CHECK:    llvm.mlir.constant(5 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(60 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(20 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(5 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(5 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-//       CHECK:    llvm.mlir.constant(1 : index) : i64
-//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-
-
-// -----
-
-func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
-  %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
-  return %0 : memref<f32>
-}
-// CHECK-LABEL: func @collapse_shape_fold_zero_dim
-//       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
 
 // -----
 
-func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
-  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
-  return %0 : memref<1x1xf32>
-}
-// CHECK-LABEL: func @expand_shape_zero_dim
-//       CHECK:   llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.mlir.constant(1 : index) : i64
-//       CHECK:   llvm.mlir.constant(1 : index) : i64
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.mlir.constant(1 : index) : i64
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-//       CHECK:   llvm.mlir.constant(1 : index) : i64
-//       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-
-// -----
-
-func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32> {
-  %0 = memref.collapse_shape %arg0 [[0], [1, 2]]:  memref<1x2x?xf32> into memref<1x?xf32>
-  return %0 : memref<1x?xf32>
-}
-// CHECK-LABEL:   func @collapse_shape_dynamic(
-// CHECK:           llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.mlir.constant(1 : index) : i64
-// CHECK:           llvm.mlir.constant(2 : index) : i64
-// CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
-// CHECK:           llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
-// CHECK:           llvm.mlir.constant(1 : index) : i64
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.mlir.constant(1 : index) : i64
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
-
-// -----
-
-func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
-  %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32>
-  return %0 : memref<1x2x?xf32>
-}
-// CHECK-LABEL:   func @expand_shape_dynamic(
-// CHECK:           llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.mlir.constant(2 : index) : i64
-// CHECK:           llvm.sdiv %{{.*}}, %{{.*}}  : i64
-// CHECK:           llvm.mlir.constant(1 : index) : i64
-// CHECK:           llvm.mlir.constant(2 : index) : i64
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.mlir.constant(1 : index) : i64
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
-
-// -----
-
-func.func @expand_shape_dynamic_with_non_identity_layout(
-            %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) ->
-            memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
-  %0 = memref.expand_shape %arg0 [[0], [1, 2]]:
-    memref<1x?xf32, strided<[?, ?], offset: ?>> into
-    memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
-  return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// memref.collapse_shape is lowered by -expand-strided-metadata, not by the memref-to-llvm pass.
+// CHECK-LABEL: func @collapse_shape_static
+// CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
+    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  return %0 : memref<3x4x5xf32>
 }
-// CHECK-LABEL:   func @expand_shape_dynamic_with_non_identity_layout(
-// CHECK:           llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.mlir.constant(2 : index) : i64
-// CHECK:           llvm.sdiv %{{.*}}, %{{.*}}  : i64
-// CHECK:           llvm.mlir.constant(1 : index) : i64
-// CHECK:           llvm.mlir.constant(2 : index) : i64
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
-// CHECK:           llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
-// CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
 
 // -----
 
@@ -579,15 +360,6 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
 
 // -----
 
-// CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout
-func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> {
-// CHECK-NOT: memref.collapse_shape
-  %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>>
-  return %1 : memref<64xf32, strided<[1], offset: ?>>
-}
-
-// -----
-
 // CHECK-LABEL: func @generic_atomic_rmw
 func.func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
   %x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> {

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
index d4406c4314f2e..c4714089c2f1e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -linalg-bufferize \
 // RUN: -arith-bufferize -tensor-bufferize -func-bufferize \
 // RUN: -finalizing-bufferize -buffer-deallocation -convert-linalg-to-llvm \
+// RUN: -expand-strided-metadata -lower-affine \
 // RUN: -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
index ce72fbc5a78a7..868d940d233de 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -linalg-bufferize \
 // RUN: -arith-bufferize -tensor-bufferize -func-bufferize \
 // RUN: -finalizing-bufferize -buffer-deallocation -convert-linalg-to-llvm \
+// RUN: -expand-strided-metadata -lower-affine \
 // RUN: -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \


        


More information about the Mlir-commits mailing list