[Mlir-commits] [mlir] 3914836 - [mlir][MemRefToLLVM] Remove the code for lowering collapse/expand_shape
Quentin Colombet
llvmlistbot at llvm.org
Mon Jan 23 07:48:34 PST 2023
Author: Quentin Colombet
Date: 2023-01-23T15:41:55Z
New Revision: 3914836273aa51f14121abd77cc6986cd3ccee11
URL: https://github.com/llvm/llvm-project/commit/3914836273aa51f14121abd77cc6986cd3ccee11
DIFF: https://github.com/llvm/llvm-project/commit/3914836273aa51f14121abd77cc6986cd3ccee11.diff
LOG: [mlir][MemRefToLLVM] Remove the code for lowering collapse/expand_shape
collapse/expand_shape are supposed to be expanded before we hit the
lowering code.
The expansion is done with the pass called expand-strided-metadata.
This patch is NFC in spirit but not in practice because
expand-strided-metadata won't try to accommodate "invalid" strides
for dynamic sizes that are 1 at runtime.
The previous code was broken in that respect too, but differently: it
handled only the case of row-major layouts.
That whole part is being reworked separately.
Differential Revision: https://reviews.llvm.org/D136483
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4d5af81757d03..7132560ad8b1d 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1374,256 +1374,8 @@ struct MemRefReshapeOpLowering
}
};
-/// Helper function to convert a vector of `OpFoldResult`s into a vector of
-/// `Value`s.
-static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
- Type &llvmIndexType,
- ArrayRef<OpFoldResult> valueOrAttrVec) {
- return llvm::to_vector<4>(
- llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
- if (auto attr = value.dyn_cast<Attribute>())
- return b.create<LLVM::ConstantOp>(loc, llvmIndexType, attr);
- return value.get<Value>();
- }));
-}
-
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<ReassociationIndices> reassociation) {
- llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
- for (auto &en : enumerate(reassociation)) {
- for (auto dim : en.value())
- expandedDimToCollapsedDim[dim] = en.index();
- }
- return expandedDimToCollapsedDim;
-}
-
-static OpFoldResult
-getExpandedOutputDimSize(OpBuilder &b, Location loc, Type &llvmIndexType,
- int64_t outDimIndex, ArrayRef<int64_t> outStaticShape,
- MemRefDescriptor &inDesc,
- ArrayRef<int64_t> inStaticShape,
- ArrayRef<ReassociationIndices> reassocation,
- DenseMap<int64_t, int64_t> &outDimToInDimMap) {
- int64_t outDimSize = outStaticShape[outDimIndex];
- if (!ShapedType::isDynamic(outDimSize))
- return b.getIndexAttr(outDimSize);
-
- // Calculate the multiplication of all the out dim sizes except the
- // current dim.
- int64_t inDimIndex = outDimToInDimMap[outDimIndex];
- int64_t otherDimSizesMul = 1;
- for (auto otherDimIndex : reassocation[inDimIndex]) {
- if (otherDimIndex == static_cast<unsigned>(outDimIndex))
- continue;
- int64_t otherDimSize = outStaticShape[otherDimIndex];
- assert(!ShapedType::isDynamic(otherDimSize) &&
- "single dimension cannot be expanded into multiple dynamic "
- "dimensions");
- otherDimSizesMul *= otherDimSize;
- }
-
- // outDimSize = inDimSize / otherOutDimSizesMul
- int64_t inDimSize = inStaticShape[inDimIndex];
- Value inDimSizeDynamic =
- ShapedType::isDynamic(inDimSize)
- ? inDesc.size(b, loc, inDimIndex)
- : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
- b.getIndexAttr(inDimSize));
- Value outDimSizeDynamic = b.create<LLVM::SDivOp>(
- loc, inDimSizeDynamic,
- b.create<LLVM::ConstantOp>(loc, llvmIndexType,
- b.getIndexAttr(otherDimSizesMul)));
- return outDimSizeDynamic;
-}
-
-static OpFoldResult getCollapsedOutputDimSize(
- OpBuilder &b, Location loc, Type &llvmIndexType, int64_t outDimIndex,
- int64_t outDimSize, ArrayRef<int64_t> inStaticShape,
- MemRefDescriptor &inDesc, ArrayRef<ReassociationIndices> reassocation) {
- if (!ShapedType::isDynamic(outDimSize))
- return b.getIndexAttr(outDimSize);
-
- Value c1 = b.create<LLVM::ConstantOp>(loc, llvmIndexType, b.getIndexAttr(1));
- Value outDimSizeDynamic = c1;
- for (auto inDimIndex : reassocation[outDimIndex]) {
- int64_t inDimSize = inStaticShape[inDimIndex];
- Value inDimSizeDynamic =
- ShapedType::isDynamic(inDimSize)
- ? inDesc.size(b, loc, inDimIndex)
- : b.create<LLVM::ConstantOp>(loc, llvmIndexType,
- b.getIndexAttr(inDimSize));
- outDimSizeDynamic =
- b.create<LLVM::MulOp>(loc, outDimSizeDynamic, inDimSizeDynamic);
- }
- return outDimSizeDynamic;
-}
-
-static SmallVector<OpFoldResult, 4>
-getCollapsedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
- ArrayRef<ReassociationIndices> reassociation,
- ArrayRef<int64_t> inStaticShape,
- MemRefDescriptor &inDesc,
- ArrayRef<int64_t> outStaticShape) {
- return llvm::to_vector<4>(llvm::map_range(
- llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
- return getCollapsedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
- outStaticShape[outDimIndex],
- inStaticShape, inDesc, reassociation);
- }));
-}
-
-static SmallVector<OpFoldResult, 4>
-getExpandedOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
- ArrayRef<ReassociationIndices> reassociation,
- ArrayRef<int64_t> inStaticShape,
- MemRefDescriptor &inDesc,
- ArrayRef<int64_t> outStaticShape) {
- DenseMap<int64_t, int64_t> outDimToInDimMap =
- getExpandedDimToCollapsedDimMap(reassociation);
- return llvm::to_vector<4>(llvm::map_range(
- llvm::seq<int64_t>(0, outStaticShape.size()), [&](int64_t outDimIndex) {
- return getExpandedOutputDimSize(b, loc, llvmIndexType, outDimIndex,
- outStaticShape, inDesc, inStaticShape,
- reassociation, outDimToInDimMap);
- }));
-}
-
-static SmallVector<Value>
-getDynamicOutputShape(OpBuilder &b, Location loc, Type &llvmIndexType,
- ArrayRef<ReassociationIndices> reassociation,
- ArrayRef<int64_t> inStaticShape, MemRefDescriptor &inDesc,
- ArrayRef<int64_t> outStaticShape) {
- return outStaticShape.size() < inStaticShape.size()
- ? getAsValues(b, loc, llvmIndexType,
- getCollapsedOutputShape(b, loc, llvmIndexType,
- reassociation, inStaticShape,
- inDesc, outStaticShape))
- : getAsValues(b, loc, llvmIndexType,
- getExpandedOutputShape(b, loc, llvmIndexType,
- reassociation, inStaticShape,
- inDesc, outStaticShape));
-}
-
-static void fillInStridesForExpandedMemDescriptor(
- OpBuilder &b, Location loc, MemRefType srcType, MemRefDescriptor &srcDesc,
- MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
- // See comments for computeExpandedLayoutMap for details on how the strides
- // are calculated.
- for (auto &en : llvm::enumerate(reassociation)) {
- auto currentStrideToExpand = srcDesc.stride(b, loc, en.index());
- for (auto dstIndex : llvm::reverse(en.value())) {
- dstDesc.setStride(b, loc, dstIndex, currentStrideToExpand);
- Value size = dstDesc.size(b, loc, dstIndex);
- currentStrideToExpand =
- b.create<LLVM::MulOp>(loc, size, currentStrideToExpand);
- }
- }
-}
-
-static void fillInStridesForCollapsedMemDescriptor(
- ConversionPatternRewriter &rewriter, Location loc, Operation *op,
- TypeConverter *typeConverter, MemRefType srcType, MemRefDescriptor &srcDesc,
- MemRefDescriptor &dstDesc, ArrayRef<ReassociationIndices> reassociation) {
- auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
- // See comments for computeCollapsedLayoutMap for details on how the strides
- // are calculated.
- auto srcShape = srcType.getShape();
- for (auto &en : llvm::enumerate(reassociation)) {
- rewriter.setInsertionPoint(op);
- auto dstIndex = en.index();
- ArrayRef<int64_t> ref = llvm::ArrayRef(en.value());
- while (srcShape[ref.back()] == 1 && ref.size() > 1)
- ref = ref.drop_back();
- if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
- dstDesc.setStride(rewriter, loc, dstIndex,
- srcDesc.stride(rewriter, loc, ref.back()));
- } else {
- // Iterate over the source strides in reverse order. Skip over the
- // dimensions whose size is 1.
- // TODO: we should take the minimum stride in the reassociation group
- // instead of just the first where the dimension is not 1.
- //
- // +------------------------------------------------------+
- // | curEntry: |
- // | %srcStride = strides[srcIndex] |
- // | %neOne = cmp sizes[srcIndex],1 +--+
- // | cf.cond_br %neOne, continue(%srcStride), nextEntry | |
- // +-------------------------+----------------------------+ |
- // | |
- // v |
- // +-----------------------------+ |
- // | nextEntry: | |
- // | ... +---+ |
- // +--------------+--------------+ | |
- // | | |
- // v | |
- // +-----------------------------+ | |
- // | nextEntry: | | |
- // | ... | | |
- // +--------------+--------------+ | +--------+
- // | | |
- // v v v
- // +--------------------------------------------------+
- // | continue(%newStride): |
- // | %newMemRefDes = setStride(%newStride,dstIndex) |
- // +--------------------------------------------------+
- OpBuilder::InsertionGuard guard(rewriter);
- Block *initBlock = rewriter.getInsertionBlock();
- Block *continueBlock =
- rewriter.splitBlock(initBlock, rewriter.getInsertionPoint());
- continueBlock->insertArgument(unsigned(0), srcDesc.getIndexType(), loc);
- rewriter.setInsertionPointToStart(continueBlock);
- dstDesc.setStride(rewriter, loc, dstIndex, continueBlock->getArgument(0));
-
- Block *curEntryBlock = initBlock;
- Block *nextEntryBlock;
- for (auto srcIndex : llvm::reverse(ref)) {
- if (srcShape[srcIndex] == 1 && srcIndex != ref.front())
- continue;
- rewriter.setInsertionPointToEnd(curEntryBlock);
- Value srcStride = srcDesc.stride(rewriter, loc, srcIndex);
- if (srcIndex == ref.front()) {
- rewriter.create<LLVM::BrOp>(loc, srcStride, continueBlock);
- break;
- }
- Value one = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
- rewriter.getIndexAttr(1));
- Value predNeOne = rewriter.create<LLVM::ICmpOp>(
- loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
- one);
- {
- OpBuilder::InsertionGuard guard(rewriter);
- nextEntryBlock = rewriter.createBlock(
- initBlock->getParent(), Region::iterator(continueBlock), {});
- }
- rewriter.create<LLVM::CondBrOp>(loc, predNeOne, continueBlock,
- srcStride, nextEntryBlock,
- std::nullopt);
- curEntryBlock = nextEntryBlock;
- }
- }
- }
-}
-
-static void fillInDynamicStridesForMemDescriptor(
- ConversionPatternRewriter &b, Location loc, Operation *op,
- TypeConverter *typeConverter, MemRefType srcType, MemRefType dstType,
- MemRefDescriptor &srcDesc, MemRefDescriptor &dstDesc,
- ArrayRef<ReassociationIndices> reassociation) {
- if (srcType.getRank() > dstType.getRank())
- fillInStridesForCollapsedMemDescriptor(b, loc, op, typeConverter, srcType,
- srcDesc, dstDesc, reassociation);
- else
- fillInStridesForExpandedMemDescriptor(b, loc, srcType, srcDesc, dstDesc,
- reassociation);
-}
-
-// ReshapeOp creates a new view descriptor of the proper rank.
-// For now, the only conversion supported is for target MemRef with static sizes
-// and strides.
+/// ReassociatingReshapeOp must be expanded before we reach this stage.
+/// Report that information.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
: public ConvertOpToLLVMPattern<ReshapeOp> {
@@ -1634,56 +1386,9 @@ class ReassociatingReshapeOpConversion
LogicalResult
matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefType dstType = reshapeOp.getResultType();
- MemRefType srcType = reshapeOp.getSrcType();
-
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- if (failed(getStridesAndOffset(dstType, strides, offset))) {
- return rewriter.notifyMatchFailure(
- reshapeOp, "failed to get stride and offset exprs");
- }
-
- MemRefDescriptor srcDesc(adaptor.getSrc());
- Location loc = reshapeOp->getLoc();
- auto dstDesc = MemRefDescriptor::undef(
- rewriter, loc, this->typeConverter->convertType(dstType));
- dstDesc.setAllocatedPtr(rewriter, loc, srcDesc.allocatedPtr(rewriter, loc));
- dstDesc.setAlignedPtr(rewriter, loc, srcDesc.alignedPtr(rewriter, loc));
- dstDesc.setOffset(rewriter, loc, srcDesc.offset(rewriter, loc));
-
- ArrayRef<int64_t> srcStaticShape = srcType.getShape();
- ArrayRef<int64_t> dstStaticShape = dstType.getShape();
- Type llvmIndexType =
- this->typeConverter->convertType(rewriter.getIndexType());
- SmallVector<Value> dstShape = getDynamicOutputShape(
- rewriter, loc, llvmIndexType, reshapeOp.getReassociationIndices(),
- srcStaticShape, srcDesc, dstStaticShape);
- for (auto &en : llvm::enumerate(dstShape))
- dstDesc.setSize(rewriter, loc, en.index(), en.value());
-
- if (llvm::all_of(strides, isStaticStrideOrOffset)) {
- for (auto &en : llvm::enumerate(strides))
- dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
- } else if (srcType.getLayout().isIdentity() &&
- dstType.getLayout().isIdentity()) {
- Value c1 = rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType,
- rewriter.getIndexAttr(1));
- Value stride = c1;
- for (auto dimIndex :
- llvm::reverse(llvm::seq<int64_t>(0, dstShape.size()))) {
- dstDesc.setStride(rewriter, loc, dimIndex, stride);
- stride = rewriter.create<LLVM::MulOp>(loc, dstShape[dimIndex], stride);
- }
- } else {
- // There could be mixed static/dynamic strides. For simplicity, we
- // recompute all strides if there is at least one dynamic stride.
- fillInDynamicStridesForMemDescriptor(
- rewriter, loc, reshapeOp, this->typeConverter, srcType, dstType,
- srcDesc, dstDesc, reshapeOp.getReassociationIndices());
- }
- rewriter.replaceOp(reshapeOp, {dstDesc});
- return success();
+ return rewriter.notifyMatchFailure(
+ reshapeOp,
+ "reassociation operations should have been expanded beforehand");
}
};
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 061cc6d17a29f..51650adc9a9ff 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -373,3 +373,298 @@ func.func @subview_negative_stride(%arg0 : memref<7xf32>) -> memref<7xf32, strid
%0 = memref.subview %arg0[6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
return %0 : memref<7xf32, strided<[-1], offset: 6>>
}
+
+// -----
+
+func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
+ memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+ return %0 : memref<3x4x5xf32>
+}
+// CHECK-LABEL: func @collapse_shape_static
+// CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x3x4x1x5xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C3]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[C20:.*]] = llvm.mlir.constant(20 : index) : i64
+// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[C20]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
+// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[C4]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : i64
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C5]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC7:.*]] = llvm.insertvalue %[[C5]], %[[DESC6]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[DESC8:.*]] = llvm.insertvalue %[[C1]], %[[DESC7]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC8]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<3x4x5xf32>
+// CHECK: return %[[RES]] : memref<3x4x5xf32>
+// CHECK: }
+
+// -----
+
+func.func @collapse_shape_dynamic_with_non_identity_layout(
+ %arg0 : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) ->
+ memref<4x?xf32, strided<[?, ?], offset: ?>> {
+ %0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
+ memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into
+ memref<4x?xf32, strided<[?, ?], offset: ?>>
+ return %0 : memref<4x?xf32, strided<[?, ?], offset: ?>>
+}
+// CHECK-LABEL: func.func @collapse_shape_dynamic_with_non_identity_layout(
+// CHECK-SAME: %[[ARG:.*]]: memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) -> memref<4x?xf32, strided<[?, ?], offset: ?>> {
+// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index
+// CHECK: %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64
+// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64
+// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
+// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
+// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C4]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> to memref<4x?xf32, strided<[?, ?], offset: ?>>
+// CHECK: return %[[RES]] : memref<4x?xf32, strided<[?, ?], offset: ?>>
+// CHECK: }
+// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
+// CHECK32: llvm.mlir.constant(1 : index) : i32
+// CHECK32: llvm.mlir.constant(4 : index) : i32
+// CHECK32: llvm.mlir.constant(1 : index) : i32
+
+// -----
+
+
+func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+ // Reshapes that expand a contiguous tensor with some 1's.
+ %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+ : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+ return %0 : memref<1x3x4x1x5xf32>
+}
+// CHECK-LABEL: func @expand_shape_static
+// CHECK-SAME: %[[ARG:.*]]: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<3x4x5xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[C60:.*]] = llvm.mlir.constant(60 : index) : i64
+// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[C60]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[C3]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[C20:.*]] = llvm.mlir.constant(20 : index) : i64
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C20]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
+// CHECK: %[[DESC7:.*]] = llvm.insertvalue %[[C4]], %[[DESC6]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : i64
+// CHECK: %[[DESC8:.*]] = llvm.insertvalue %[[C5]], %[[DESC7]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[DESC9:.*]] = llvm.insertvalue %[[C1]], %[[DESC8]][3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[DESC10:.*]] = llvm.insertvalue %[[C5]], %[[DESC9]][4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[DESC11:.*]] = llvm.insertvalue %[[C5]], %[[DESC10]][3, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[DESC12:.*]] = llvm.insertvalue %[[C1]], %[[DESC11]][4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC12]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)> to memref<1x3x4x1x5xf32>
+// CHECK: return %[[RES]] : memref<1x3x4x1x5xf32>
+// CHECK: }
+
+// -----
+
+func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
+ %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
+ return %0 : memref<f32>
+}
+// CHECK-LABEL: func.func @collapse_shape_fold_zero_dim(
+// CHECK-SAME: %[[ARG:.*]]: memref<1x1xf32>) -> memref<f32> {
+// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x1xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC2]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)> to memref<f32>
+// CHECK: return %[[RES]] : memref<f32>
+// CHECK: }
+
+// -----
+
+func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
+ %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
+ return %0 : memref<1x1xf32>
+}
+
+// CHECK-LABEL: func.func @expand_shape_zero_dim(
+// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<1x1xf32> {
+// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[C1]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[C1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> to memref<1x1xf32>
+// CHECK: return %[[RES]] : memref<1x1xf32>
+// CHECK: }
+
+// -----
+
+func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32> {
+ %0 = memref.collapse_shape %arg0 [[0], [1, 2]]: memref<1x2x?xf32> into memref<1x?xf32>
+ return %0 : memref<1x?xf32>
+}
+
+// CHECK-LABEL: func.func @collapse_shape_dynamic(
+// CHECK-SAME: %[[ARG:.*]]: memref<1x2x?xf32>) -> memref<1x?xf32> {
+// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x2x?xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] : i64
+// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
+// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[IS_MIN_STRIDE1:.*]] = llvm.icmp "slt" %[[STRIDE1]], %[[C1]] : i64
+// CHECK: %[[MIN_STRIDE1:.*]] = llvm.select %[[IS_MIN_STRIDE1]], %[[STRIDE1]], %[[C1]] : i1, i64
+// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
+// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[MIN_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> to memref<1x?xf32>
+// CHECK: return %[[RES]] : memref<1x?xf32>
+// CHECK: }
+
+// -----
+
+func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
+ %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32>
+ return %0 : memref<1x2x?xf32>
+}
+
+// CHECK-LABEL: func.func @expand_shape_dynamic(
+// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32>) -> memref<1x2x?xf32> {
+// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK: %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
+// CHECK: %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
+// CHECK: %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]] : i64
+// CHECK: %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
+// CHECK: %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]] : i64
+// CHECK: %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]] : i64
+// CHECK: %[[FINAL_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
+// CHECK: %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE2]] : i64 to index
+// CHECK: %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[C2]], %[[DESC4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// In this example stride1 and size2 are the same.
+// Hence with CSE, we get the same SSA value.
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC6]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC8:.*]] = llvm.insertvalue %[[C1]], %[[DESC7]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC8]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32>
+// CHECK: return %[[RES]] : memref<1x2x?xf32>
+// CHECK: }
+
+// -----
+
+func.func @expand_shape_dynamic_with_non_identity_layout(
+ %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) ->
+ memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
+ %0 = memref.expand_shape %arg0 [[0], [1, 2]]:
+ memref<1x?xf32, strided<[?, ?], offset: ?>> into
+ memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+ return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+}
+// CHECK-LABEL: func.func @expand_shape_dynamic_with_non_identity_layout(
+// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
+// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK: %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
+// CHECK: %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
+// CHECK: %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]] : i64
+// CHECK: %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
+// CHECK: %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]] : i64
+// CHECK: %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]] : i64
+// CHECK: %[[TMP_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
+// CHECK: %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[TMP_SIZE2]] : i64 to index
+// CHECK: %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
+// CHECK: %[[FINAL_STRIDE1:.*]] = llvm.mul %[[TMP_SIZE2]], %[[STRIDE1]]
+// CHECK: %[[STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_STRIDE1]] : i64 to index
+// CHECK: %[[FINAL_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[STRIDE1_TO_IDX]] : index to i64
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC1]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC2]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[C1]], %[[DESC3]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC4]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C2]], %[[DESC5]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_STRIDE1]], %[[DESC6]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC8:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC7]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC9:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESC8]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC9]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: return %[[RES]] : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout
+func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> {
+// CHECK-NOT: memref.collapse_shape
+ %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>>
+ return %1 : memref<64xf32, strided<[1], offset: ?>>
+}
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index c12a07881cd46..1a8a75d1e2b35 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -290,246 +290,27 @@ memref.global "private" @gv4 : memref<f32> = dense<1.0> {alignment = 64}
// -----
-func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
- %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
- memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
- return %0 : memref<3x4x5xf32>
-}
-// CHECK-LABEL: func @collapse_shape_static
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(3 : index) : i64
-// CHECK: llvm.mlir.constant(4 : index) : i64
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(20 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-
-// -----
-
-func.func @collapse_shape_dynamic_with_non_identity_layout(
- %arg0 : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) ->
- memref<4x?xf32, strided<[?, ?], offset: ?>> {
- %0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
- memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into
- memref<4x?xf32, strided<[?, ?], offset: ?>>
- return %0 : memref<4x?xf32, strided<[?, ?], offset: ?>>
-}
-// CHECK-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.mlir.constant(4 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i64), ^bb1
-// CHECK: ^bb1:
-// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.br ^bb2(%{{.*}} : i64)
-// CHECK: ^bb2(%[[STRIDE:.*]]: i64):
-// CHECK: llvm.insertvalue %[[STRIDE]], %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK32-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
-// CHECK32: llvm.mlir.constant(1 : index) : i32
-// CHECK32: llvm.mlir.constant(4 : index) : i32
-// CHECK32: llvm.mlir.constant(1 : index) : i32
-
-// -----
-
+// Expand shapes need to be expanded outside of the memref-to-llvm pass.
+// CHECK-LABEL: func @expand_shape_static(
+// CHECK-SAME: %[[ARG:.*]]: memref<{{.*}}>)
func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+ // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]]
// Reshapes that expand a contiguous tensor with some 1's.
%0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
return %0 : memref<1x3x4x1x5xf32>
}
-// CHECK-LABEL: func @expand_shape_static
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.mlir.constant(3 : index) : i64
-// CHECK: llvm.mlir.constant(4 : index) : i64
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(60 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(20 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-
-
-// -----
-
-func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
- %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
- return %0 : memref<f32>
-}
-// CHECK-LABEL: func @collapse_shape_fold_zero_dim
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
// -----
-func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
- %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
- return %0 : memref<1x1xf32>
-}
-// CHECK-LABEL: func @expand_shape_zero_dim
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-
-// -----
-
-func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32> {
- %0 = memref.collapse_shape %arg0 [[0], [1, 2]]: memref<1x2x?xf32> into memref<1x?xf32>
- return %0 : memref<1x?xf32>
-}
-// CHECK-LABEL: func @collapse_shape_dynamic(
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.mlir.constant(2 : index) : i64
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-
-// -----
-
-func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
- %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32>
- return %0 : memref<1x2x?xf32>
-}
-// CHECK-LABEL: func @expand_shape_dynamic(
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(2 : index) : i64
-// CHECK: llvm.sdiv %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.mlir.constant(2 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-
-// -----
-
-func.func @expand_shape_dynamic_with_non_identity_layout(
- %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) ->
- memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
- %0 = memref.expand_shape %arg0 [[0], [1, 2]]:
- memref<1x?xf32, strided<[?, ?], offset: ?>> into
- memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
- return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// Collapse shapes need to be expanded outside of the memref-to-llvm pass.
+// CHECK-LABEL: func @collapse_shape_static
+// CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
+ memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+ return %0 : memref<3x4x5xf32>
}
-// CHECK-LABEL: func @expand_shape_dynamic_with_non_identity_layout(
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(2 : index) : i64
-// CHECK: llvm.sdiv %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.mlir.constant(2 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mul %{{.*}}, %{{.*}} : i64
// -----
@@ -579,15 +360,6 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
// -----
-// CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout
-func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> {
-// CHECK-NOT: memref.collapse_shape
- %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>>
- return %1 : memref<64xf32, strided<[1], offset: ?>>
-}
-
-// -----
-
// CHECK-LABEL: func @generic_atomic_rmw
func.func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
%x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
index d4406c4314f2e..c4714089c2f1e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-collapse-tensor.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -linalg-bufferize \
// RUN: -arith-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize -buffer-deallocation -convert-linalg-to-llvm \
+// RUN: -expand-strided-metadata -lower-affine \
// RUN: -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
index ce72fbc5a78a7..868d940d233de 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -linalg-bufferize \
// RUN: -arith-bufferize -tensor-bufferize -func-bufferize \
// RUN: -finalizing-bufferize -buffer-deallocation -convert-linalg-to-llvm \
+// RUN: -expand-strided-metadata -lower-affine \
// RUN: -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \
More information about the Mlir-commits
mailing list