[Mlir-commits] [mlir] 46ef86b - [mlir] Move linalg::Expand/CollapseShapeOp to memref dialect.
Alexander Belyaev
llvmlistbot at llvm.org
Fri Jul 16 04:32:50 PDT 2021
Author: Alexander Belyaev
Date: 2021-07-16T13:32:17+02:00
New Revision: 46ef86b5d82ea8ec36de680f10b69d36487e4d1d
URL: https://github.com/llvm/llvm-project/commit/46ef86b5d82ea8ec36de680f10b69d36487e4d1d
DIFF: https://github.com/llvm/llvm-project/commit/46ef86b5d82ea8ec36de680f10b69d36487e4d1d.diff
LOG: [mlir] Move linalg::Expand/CollapseShapeOp to memref dialect.
RFC: https://llvm.discourse.group/t/rfc-reshape-ops-restructuring/3310
Differential Revision: https://reviews.llvm.org/D106141
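At the IR level this change is a pure rename of the reassociative reshape ops from the `linalg` to the `memref` dialect; assembly format and semantics are unchanged. A minimal before/after sketch:

```mlir
// Before: reassociative memref reshapes lived in the linalg dialect.
%0 = linalg.collapse_shape %arg0 [[0, 1]] : memref<4x5xf32> into memref<20xf32>

// After: the same op, same syntax, in the memref dialect.
%1 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x5xf32> into memref<20xf32>
```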
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Dialect/Linalg/bufferize.mlir
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/llvm.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/MemRef/ops.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index cb0507d6f903..8d14880c148b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -396,91 +396,6 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
def IndexListArrayAttr :
TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
-class Linalg_ReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<mnemonic,
- [DeclareOpInterfaceMethods<ViewLikeOpInterface>]>,
- Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
- Results<(outs AnyStridedMemRef:$result)> {
- let extraClassDeclaration = commonExtraClassDeclaration # [{
- MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
- MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
- }];
- let hasFolder = 1;
- let hasCanonicalizer = 1;
- let printer = [{ return ::print(p, *this); }];
-}
-
-def Linalg_ExpandShapeOp : Linalg_ReshapeOp<"expand_shape"> {
- let summary = "operation to produce a memref with a higher rank.";
- let description = [{
- The `linalg.expand_shape` op produces a new view with a higher rank whose
- sizes are a reassociation of the original `view`. Depending on whether or
- not the reassociated MemRefType is contiguous, the resulting memref may
- require explicit alloc and copies.
-
- A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of I64ArrayAttr attributes.
-
- For now, it is assumed that either:
- 1. a reassociation produces and consumes contiguous MemRefType or,
- 2. the reshape op will be folded into its consumers (by changing the shape
- of the computations).
- All other cases are undefined behavior and a reshape op may not lower to
- LLVM if it cannot be proven statically that it does not require alloc+copy.
-
-    The operand memref type of a reshape can be zero-ranked if the result
-    memref type is statically shaped with all dimensions being unit extent. In
-    such a case the reassociation map is empty.
-
- The verification rule is that the reassociation maps are applied to the
- result memref with the larger rank to obtain the operand memref with the
- smaller rank.
-
- Example:
-
- ```mlir
- // Dimension expansion i -> (i', j') and (k) -> (k')
- %1 = linalg.expand_shape %0 [[0, 1], [2]] :
- memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
- ```
- }];
-}
-
-def Linalg_CollapseShapeOp : Linalg_ReshapeOp<"collapse_shape"> {
- let summary = "operation to produce a memref with a smaller rank.";
- let description = [{
- The `linalg.collapse_shape` op produces a new view with a smaller rank
- whose sizes are a reassociation of the original `view`. Depending on
- whether or not the reassociated MemRefType is contiguous, the resulting
- memref may require explicit alloc and copies.
-
- A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of I64ArrayAttr attributes.
-
- For now, it is assumed that either:
- 1. a reassociation produces and consumes contiguous MemRefType or,
- 2. the reshape op will be folded into its consumers (by changing the shape
- of the computations).
- All other cases are undefined behavior and a reshape op may not lower to
- LLVM if it cannot be proven statically that it does not require alloc+copy.
-
- The result memref type of a reshape can be zero-ranked if the operand
- memref type is statically shaped with all dimensions being unit extent. In
-    such a case the reassociation map is empty.
-
- The verification rule is that the reassociation maps are applied to the
- operand memref with the larger rank to obtain the result memref with the
- smaller rank.
-
- Examples:
-
- ```mlir
- // Dimension collapse (i, j) -> i' and k -> k'
- %1 = linalg.collapse_shape %0 [[0, 1], [2]] :
- memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
- ```
- }];
-}
-
class Linalg_TensorReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<
mnemonic,
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index c0f3c3445823..8c61248567f6 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_MEMREF_IR_MEMREF_H_
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index e60cf298d36e..3385a7e8e78c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -203,7 +203,7 @@ def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
// AllocaScopeOp
//===----------------------------------------------------------------------===//
-def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
+def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
[AutomaticAllocationScope,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"AllocaScopeReturnOp">,
@@ -225,8 +225,8 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
Here, `%myalloca` memref is valid within the explicitly delimited scope
and is automatically deallocated at the end of the given region. Conceptually,
- `memref.alloca_scope` is a passthrough operation with
- `AutomaticAllocationScope` that spans the body of the region within the operation.
+ `memref.alloca_scope` is a passthrough operation with
+ `AutomaticAllocationScope` that spans the body of the region within the operation.
`memref.alloca_scope` may also return results that are defined in the nested
region. To return a value, one should use `memref.alloca_scope.return`
@@ -251,14 +251,14 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
// AllocaScopeReturnOp
//===----------------------------------------------------------------------===//
-def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
+def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
[HasParent<"AllocaScopeOp">,
NoSideEffect,
ReturnLike,
Terminator]> {
let summary = "terminator for alloca_scope operation";
let description = [{
- `memref.alloca_scope.return` operation returns zero or more SSA values
+ `memref.alloca_scope.return` operation returns zero or more SSA values
from the region within `memref.alloca_scope`. If no values are returned,
the return operation may be omitted. Otherwise, it has to be present
to indicate which values are going to be returned. For example:
@@ -927,6 +927,150 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
}];
}
+//===----------------------------------------------------------------------===//
+// ExpandShapeOp / CollapseShapeOp
+//===----------------------------------------------------------------------===//
+
+def IndexListArrayAttr :
+ TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;
+
+class MemRef_ReassociativeReshapeOp<string mnemonic, list<OpTrait> traits = []> :
+ MemRef_Op<mnemonic, !listconcat(traits,
+ [NoSideEffect, ViewLikeOpInterface])>,
+ Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
+ Results<(outs AnyStridedMemRef:$result)>{
+ let builders = [
+ // Builders for a contracting reshape whose result type is computed from
+ // `src` and `reassociation`.
+ OpBuilder<(ins "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ OpBuilder<(ins "Value":$src,
+ "ArrayRef<ReassociationExprs>":$reassociation,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ [{
+ auto reassociationMaps =
+ convertReassociationMapsToIndices($_builder, reassociation);
+ build($_builder, $_state, src, reassociationMaps, attrs);
+ }]>,
+
+ // Builders for a reshape whose result type is passed explicitly. This may
+ // be either a contracting or expanding reshape.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ [{
+ build($_builder, $_state, resultType, src, attrs);
+ $_state.addAttribute("reassociation",
+ getReassociationIndicesAttribute($_builder, reassociation));
+ }]>,
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationExprs>":$reassociation,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ [{
+ auto reassociationMaps =
+ convertReassociationMapsToIndices($_builder, reassociation);
+ build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+ }]>
+ ];
+
+ code commonExtraClassDeclaration = [{
+ SmallVector<AffineMap, 4> getReassociationMaps();
+ SmallVector<ReassociationExprs, 4> getReassociationExprs();
+ SmallVector<ReassociationIndices, 4> getReassociationIndices() {
+ SmallVector<ReassociationIndices, 4> reassociationIndices;
+ for (auto attr : reassociation())
+ reassociationIndices.push_back(llvm::to_vector<2>(
+ llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
+ return indexAttr.cast<IntegerAttr>().getInt();
+ })));
+ return reassociationIndices;
+ };
+ MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
+ MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
+ Value getViewSource() { return src(); }
+ }];
+
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+ let printer = [{ return ::print(p, *this); }];
+ let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
+}
+
+def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
+ let summary = "operation to produce a memref with a higher rank.";
+ let description = [{
+ The `memref.expand_shape` op produces a new view with a higher rank whose
+ sizes are a reassociation of the original `view`. Depending on whether or
+ not the reassociated MemRefType is contiguous, the resulting memref may
+ require explicit alloc and copies.
+
+ A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of I64ArrayAttr attributes.
+
+ For now, it is assumed that either:
+ 1. a reassociation produces and consumes contiguous MemRefType or,
+ 2. the reshape op will be folded into its consumers (by changing the shape
+ of the computations).
+ All other cases are undefined behavior and a reshape op may not lower to
+ LLVM if it cannot be proven statically that it does not require alloc+copy.
+
+    The operand memref type of a reshape can be zero-ranked if the result
+    memref type is statically shaped with all dimensions being unit extent. In
+    such a case the reassociation map is empty.
+
+ The verification rule is that the reassociation maps are applied to the
+ result memref with the larger rank to obtain the operand memref with the
+ smaller rank.
+
+ Example:
+
+ ```mlir
+ // Dimension expansion i -> (i', j') and (k) -> (k')
+ %1 = memref.expand_shape %0 [[0, 1], [2]] :
+ memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
+ ```
+ }];
+ let extraClassDeclaration = commonExtraClassDeclaration;
+}
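As an additional illustration (mirroring the `expand_shape_zero_dim` test added below), a rank-0 operand expands with an empty reassociation when every result dimension has unit extent:

```mlir
// Empty reassociation: all result dimensions are unit extents.
%0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
```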
+
+def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
+ let summary = "operation to produce a memref with a smaller rank.";
+ let description = [{
+ The `memref.collapse_shape` op produces a new view with a smaller rank
+ whose sizes are a reassociation of the original `view`. Depending on
+ whether or not the reassociated MemRefType is contiguous, the resulting
+ memref may require explicit alloc and copies.
+
+ A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of I64ArrayAttr attributes.
+
+ For now, it is assumed that either:
+ 1. a reassociation produces and consumes contiguous MemRefType or,
+ 2. the reshape op will be folded into its consumers (by changing the shape
+ of the computations).
+ All other cases are undefined behavior and a reshape op may not lower to
+ LLVM if it cannot be proven statically that it does not require alloc+copy.
+
+ The result memref type of a reshape can be zero-ranked if the operand
+ memref type is statically shaped with all dimensions being unit extent. In
+    such a case the reassociation map is empty.
+
+ The verification rule is that the reassociation maps are applied to the
+ operand memref with the larger rank to obtain the result memref with the
+ smaller rank.
+
+ Examples:
+
+ ```mlir
+ // Dimension collapse (i, j) -> i' and k -> k'
+ %1 = memref.collapse_shape %0 [[0, 1], [2]] :
+ memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
+ ```
+ }];
+ let extraClassDeclaration = commonExtraClassDeclaration;
+}
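The dual zero-rank case (mirroring the `collapse_shape_fold_zero_dim` test added below):

```mlir
// Empty reassociation: a statically shaped, all-unit-extent memref
// collapses to rank 0.
%0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
```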
+
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3a557a63fd26..57350c5b1406 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -1,4 +1,4 @@
-//===- RehshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
+//===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -26,7 +26,7 @@ using ReassociationIndicesRef = ArrayRef<int64_t>;
using ReassociationExprs = SmallVector<AffineExpr, 2>;
/// Attribute name for the ArrayAttr which encodes reassociation indices.
-constexpr StringRef getReassociationAttrName();
+constexpr StringRef getReassociationAttrName() { return "reassociation"; }
/// Compose reassociation maps that are used in pair of reshape ops where one
/// is a producer and other is the consumer. Only valid to use this method when
@@ -45,6 +45,23 @@ Optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
ArrayRef<ReassociationIndices> consumerReassociations,
MLIRContext *context);
+/// Convert reassociation indices to affine expressions.
+SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
+ OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices);
+
+/// Constructs affine maps out of Array<Array<AffineExpr>>.
+SmallVector<AffineMap, 4>
+getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);
+
+/// Wraps a list of reassociations in an ArrayAttr.
+ArrayAttr
+getReassociationIndicesAttribute(OpBuilder &b,
+ ArrayRef<ReassociationIndices> reassociation);
+
+/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
+SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
+ OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+
/// Return the reassociation maps to use to reshape given the source type and
/// the target type when possible. Return llvm::None when this computation
/// fails.
@@ -78,7 +95,7 @@ void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) {
p << "] ";
p.printOptionalAttrDict(op->getAttrs(),
- /*elidedAttrs=*/{op.getReassociationAttrName()});
+ /*elidedAttrs=*/{getReassociationAttrName()});
p << ": " << op.src().getType() << " into " << op.getType();
}
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 9357407f5521..580bc059e14f 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
@@ -93,48 +94,6 @@ class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
}
};
-// ReshapeOp creates a new view descriptor of the proper rank.
-// For now, the only conversion supported is for target MemRef with static sizes
-// and strides.
-template <typename ReshapeOp>
-class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
-public:
- using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
- using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
-
- LogicalResult
- matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- MemRefType dstType = reshapeOp.getResultType();
-
- if (!dstType.hasStaticShape())
- return failure();
-
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto res = getStridesAndOffset(dstType, strides, offset);
- if (failed(res) || llvm::any_of(strides, [](int64_t val) {
- return ShapedType::isDynamicStrideOrOffset(val);
- }))
- return failure();
-
- ReshapeOpAdaptor adaptor(operands);
- MemRefDescriptor baseDesc(adaptor.src());
- Location loc = reshapeOp->getLoc();
- auto desc =
- MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
- this->typeConverter->convertType(dstType));
- desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc));
- desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc));
- desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc));
- for (auto en : llvm::enumerate(dstType.getShape()))
- desc.setConstantSize(rewriter, loc, en.index(), en.value());
- for (auto en : llvm::enumerate(strides))
- desc.setConstantStride(rewriter, loc, en.index(), en.value());
- rewriter.replaceOp(reshapeOp, {desc});
- return success();
- }
-};
// YieldOp produces an LLVM::ReturnOp.
class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
@@ -153,9 +112,7 @@ class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
/// Populate the given list with patterns that convert from Linalg to LLVM.
void mlir::populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
- patterns.add<RangeOpConversion, ReshapeOpConversion<ExpandShapeOp>,
- ReshapeOpConversion<CollapseShapeOp>, YieldOpConversion>(
- converter);
+ patterns.add<RangeOpConversion, YieldOpConversion>(converter);
// Populate the type conversions for the linalg types.
converter.addConversion(
@@ -176,6 +133,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateLinalgToLLVMConversionPatterns(converter, patterns);
+ populateMemRefToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addIllegalOp<RangeOp>();
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 6f422e5f629f..f47b8878a17c 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -186,9 +186,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
StandardOpsDialect>();
- target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
- target.addLegalOp<linalg::ExpandShapeOp, linalg::CollapseShapeOp,
- linalg::RangeOp>();
+ target.addLegalOp<ModuleOp, FuncOp, ReturnOp, linalg::RangeOp>();
RewritePatternSet patterns(&getContext());
populateLinalgToStandardConversionPatterns(patterns);
if (failed(applyFullConversion(module, target, std::move(patterns))))
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index a92c4069fcad..4d5fe3307a8a 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1000,6 +1000,49 @@ struct MemRefReshapeOpLowering
}
};
+// ReshapeOp creates a new view descriptor of the proper rank.
+// For now, the only supported conversion is for target MemRefs with static
+// sizes and strides.
+template <typename ReshapeOp>
+class ReassociatingReshapeOpConversion
+ : public ConvertOpToLLVMPattern<ReshapeOp> {
+public:
+ using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
+ using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
+
+ LogicalResult
+ matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ MemRefType dstType = reshapeOp.getResultType();
+
+ if (!dstType.hasStaticShape())
+ return failure();
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto res = getStridesAndOffset(dstType, strides, offset);
+ if (failed(res) || llvm::any_of(strides, [](int64_t val) {
+ return ShapedType::isDynamicStrideOrOffset(val);
+ }))
+ return failure();
+
+ ReshapeOpAdaptor adaptor(operands);
+ MemRefDescriptor baseDesc(adaptor.src());
+ Location loc = reshapeOp->getLoc();
+ auto desc =
+ MemRefDescriptor::undef(rewriter, reshapeOp->getLoc(),
+ this->typeConverter->convertType(dstType));
+ desc.setAllocatedPtr(rewriter, loc, baseDesc.allocatedPtr(rewriter, loc));
+ desc.setAlignedPtr(rewriter, loc, baseDesc.alignedPtr(rewriter, loc));
+ desc.setOffset(rewriter, loc, baseDesc.offset(rewriter, loc));
+ for (auto en : llvm::enumerate(dstType.getShape()))
+ desc.setConstantSize(rewriter, loc, en.index(), en.value());
+ for (auto en : llvm::enumerate(strides))
+ desc.setConstantStride(rewriter, loc, en.index(), en.value());
+ rewriter.replaceOp(reshapeOp, {desc});
+ return success();
+ }
+};
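The early `failure()` exits above mean the pattern only fires for result types with fully static sizes and strides; a sketch of what is and is not matched (the dynamic example is illustrative):

```mlir
// Static result type: rewritten into LLVM descriptor manipulation
// (see the memref-to-llvm.mlir tests below).
%a = memref.collapse_shape %arg0 [[0, 1]] : memref<4x5xf32> into memref<20xf32>

// Dynamic result shape: matchAndRewrite returns failure() and the op is
// left for other patterns to handle.
%b = memref.collapse_shape %arg1 [[0, 1]] : memref<?x5xf32> into memref<?xf32>
```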
/// Conversion pattern that transforms a subview op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
@@ -1355,6 +1398,8 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
MemRefReinterpretCastOpLowering,
MemRefReshapeOpLowering,
PrefetchOpLowering,
+ ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
+ ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
StoreOpLowering,
SubViewOpLowering,
TransposeOpLowering,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 12fbb8ce839b..b95fc5da547b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
@@ -1103,14 +1104,6 @@ OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
// ReshapeOp
//===----------------------------------------------------------------------===//
-static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) {
- ::mlir::printReshapeOp<linalg::ExpandShapeOp>(p, op);
-}
-
-static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) {
- ::mlir::printReshapeOp<linalg::CollapseShapeOp>(p, op);
-}
-
static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
::mlir::printReshapeOp<linalg::TensorExpandShapeOp>(p, op);
}
@@ -1260,20 +1253,6 @@ convertReassociationIndicesToExprs(
return reassociationMaps;
}
-SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
- return getSymbolLessAffineMaps(getReassociationExprs());
-}
-SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
- OpBuilder b(this->getContext());
- return convertReassociationIndicesToExprs(b, getReassociationIndices());
-}
-SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
- return getSymbolLessAffineMaps(getReassociationExprs());
-}
-SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
- OpBuilder b(this->getContext());
- return convertReassociationIndicesToExprs(b, getReassociationIndices());
-}
SmallVector<AffineMap, 4> TensorCollapseShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
@@ -1422,71 +1401,6 @@ getReassociationIndicesAttribute(OpBuilder &b,
return b.getArrayAttr(reassociationAttr);
}
-void mlir::linalg::ExpandShapeOp::build(
- OpBuilder &b, OperationState &result, Value src,
- ArrayRef<ReassociationIndices> reassociation,
- ArrayRef<NamedAttribute> attrs) {
- auto memRefType = src.getType().cast<MemRefType>();
- auto resultType = computeReshapeCollapsedType(
- memRefType, getSymbolLessAffineMaps(
- convertReassociationIndicesToExprs(b, reassociation)));
- build(b, result, resultType, src, attrs);
- result.addAttribute(getReassociationAttrName(),
- getReassociationIndicesAttribute(b, reassociation));
-}
-
-Value mlir::linalg::ExpandShapeOp::getViewSource() { return src(); }
-
-void mlir::linalg::CollapseShapeOp::build(
- OpBuilder &b, OperationState &result, Value src,
- ArrayRef<ReassociationIndices> reassociation,
- ArrayRef<NamedAttribute> attrs) {
- auto memRefType = src.getType().cast<MemRefType>();
- auto resultType = computeReshapeCollapsedType(
- memRefType, getSymbolLessAffineMaps(
- convertReassociationIndicesToExprs(b, reassociation)));
- build(b, result, resultType, src, attrs);
- result.addAttribute(getReassociationAttrName(),
- getReassociationIndicesAttribute(b, reassociation));
-}
-
-Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); }
-
-template <typename ReshapeOp,
- bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
-static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
- MemRefType collapsedType) {
- if (failed(
- verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
- return failure();
- auto maps = op.getReassociationMaps();
- MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
- if (collapsedType != expectedType)
- return op.emitOpError("expected collapsed type to be ")
- << expectedType << ", but got " << collapsedType;
- return success();
-}
-
-static LogicalResult verify(ExpandShapeOp op) {
- return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
-}
-
-void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<CollapseReshapeOps<ExpandShapeOp>,
- CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
-}
-
-static LogicalResult verify(CollapseShapeOp op) {
- return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
-}
-
-void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<CollapseReshapeOps<CollapseShapeOp>,
- CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
-}
-
//===----------------------------------------------------------------------===//
// TensorReshapeOp
//===----------------------------------------------------------------------===//
@@ -2433,16 +2347,6 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
// TODO: Consider making all this boilerplate easy to autogenerate
// with Tablegen. This seems a desirable property in the context of
// OpInterfaces where a Linalg "named" op **isa** LinalgOp.
-OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
- if (succeeded(foldMemRefCast(*this)))
- return getResult();
- return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
-}
-OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
- if (succeeded(foldMemRefCast(*this)))
- return getResult();
- return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
-}
OpFoldResult TensorExpandShapeOp::fold(ArrayRef<Attribute> operands) {
return foldReshapeOp<TensorExpandShapeOp, TensorCollapseShapeOp>(*this,
operands);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index e4dcbde6f1da..b918b98a76b8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -155,8 +155,8 @@ class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
using ReshapeOp = typename std::conditional_t<
- std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value, ExpandShapeOp,
- CollapseShapeOp>;
+ std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
+ memref::ExpandShapeOp, memref::CollapseShapeOp>;
LogicalResult
matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
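The net effect of this pattern, sketched at the IR level (this matches the updated `bufferize_tensor_collapse_shape` test below):

```mlir
// Before bufferization:
%r = linalg.tensor_collapse_shape %t [[0, 1]] : tensor<4x5xf32> into tensor<20xf32>

// After bufferization:
%m = memref.buffer_cast %t : memref<4x5xf32>
%c = memref.collapse_shape %m [[0, 1]] : memref<4x5xf32> into memref<20xf32>
%r2 = memref.tensor_load %c : memref<20xf32>
```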
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 865d4bf9227c..de847d1f0fe7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -352,7 +352,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
convertAffineMapArrayToExprs(reassociationMap));
}
if (origResultType.isa<MemRefType>()) {
- return rewriter.create<linalg::ExpandShapeOp>(
+ return rewriter.create<memref::ExpandShapeOp>(
loc, origResultType, result,
convertAffineMapArrayToExprs(reassociationMap));
}
@@ -368,7 +368,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
if (operandType == newInputOutputType)
return operand;
if (operandType.isa<MemRefType>()) {
- return rewriter.create<linalg::CollapseShapeOp>(
+ return rewriter.create<memref::CollapseShapeOp>(
loc, newInputOutputType, operand,
convertAffineMapArrayToExprs(reassociationMap));
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 518539376c9f..03fa871b1f7a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1300,6 +1300,189 @@ static LogicalResult verify(ReinterpretCastOp op) {
return success();
}
+//===----------------------------------------------------------------------===//
+// Reassociative reshape ops
+//===----------------------------------------------------------------------===//
+
+SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
+ return getSymbolLessAffineMaps(getReassociationExprs());
+}
+SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
+ OpBuilder b(this->getContext());
+ return convertReassociationIndicesToExprs(b, getReassociationIndices());
+}
+
+SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
+ return getSymbolLessAffineMaps(getReassociationExprs());
+}
+SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
+ OpBuilder b(this->getContext());
+ return convertReassociationIndicesToExprs(b, getReassociationIndices());
+}
+
+static void print(OpAsmPrinter &p, ExpandShapeOp op) {
+ ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
+}
+
+static void print(OpAsmPrinter &p, CollapseShapeOp op) {
+ ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
+}
+
+/// Detect whether memref dims [dim, dim + extent) can be reshaped without
+/// copies.
+static bool isReshapableDimBand(unsigned dim, unsigned extent,
+ ArrayRef<int64_t> sizes,
+ ArrayRef<AffineExpr> strides) {
+ assert(sizes.size() == strides.size() && "mismatched ranks");
+  // Note: off-by-one indexing (`idx + 1` in the loop below) avoids running
+  // out of bounds at the end of the band.
+ for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
+ // Only bands of static shapes are reshapable. This is due to the fact that
+ // there is no relation between dynamic sizes and dynamic strides: we do not
+ // have enough information to know whether a "-1" size corresponds to the
+ // proper symbol in the AffineExpr of a stride.
+ if (ShapedType::isDynamic(sizes[dim + 1]))
+ return false;
+ // TODO: Refine this by passing the proper nDims and nSymbols so we can
+ // simplify on the fly and catch more reshapable cases.
+ if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
+ return false;
+ }
+ return true;
+}
+
+/// Compute the MemRefType obtained by applying the `reassociation` (which is
+/// expected to be valid) to `type`.
+/// If `type` is a contiguous MemRefType, this always produces a contiguous
+/// MemRefType.
+static MemRefType
+computeReshapeCollapsedType(MemRefType type,
+ ArrayRef<AffineMap> reassociation) {
+ auto sizes = type.getShape();
+ AffineExpr offset;
+ SmallVector<AffineExpr, 4> strides;
+ auto status = getStridesAndOffset(type, strides, offset);
+ (void)status;
+ assert(succeeded(status) && "expected strided memref");
+
+ SmallVector<int64_t, 4> newSizes;
+ newSizes.reserve(reassociation.size());
+ SmallVector<AffineExpr, 4> newStrides;
+ newStrides.reserve(reassociation.size());
+
+ // Use the fact that reassociation is valid to simplify the logic: only use
+ // each map's rank.
+ assert(isReassociationValid(reassociation) && "invalid reassociation");
+ unsigned currentDim = 0;
+ for (AffineMap m : reassociation) {
+ unsigned dim = m.getNumResults();
+ int64_t size = 1;
+ AffineExpr stride = strides[currentDim + dim - 1];
+ if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
+ size = ShapedType::kDynamicSize;
+ stride = AffineExpr();
+ } else {
+ for (unsigned d = 0; d < dim; ++d)
+ size *= sizes[currentDim + d];
+ }
+ newSizes.push_back(size);
+ newStrides.push_back(stride);
+ currentDim += dim;
+ }
+
+ // Early-exit: if `type` is contiguous, the result must be contiguous.
+ if (canonicalizeStridedLayout(type).getAffineMaps().empty())
+ return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
+
+ // Convert back to int64_t because we don't have enough information to create
+ // new strided layouts from AffineExpr only. This corresponds to a case where
+ // copies may be necessary.
+ int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
+ if (auto o = offset.dyn_cast<AffineConstantExpr>())
+ intOffset = o.getValue();
+ SmallVector<int64_t, 4> intStrides;
+ intStrides.reserve(strides.size());
+ for (auto stride : newStrides) {
+ if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
+ intStrides.push_back(cst.getValue());
+ else
+ intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
+ }
+ auto layout =
+ makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
+ return canonicalizeStridedLayout(
+ MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
+}
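A worked example of the band check and the resulting type (the padded layout is a hypothetical illustration): for a contiguous `memref<4x5xf32>` the strides are `[5, 1]`, so `strides[0] == strides[1] * sizes[1]` and the band `(0, 1)` collapses to a static size of 20; with a padded row stride of 8 the equality fails, and the collapsed size and stride are computed as dynamic:

```mlir
// Contiguous band: strides [5, 1] satisfy 5 == 1 * 5, so the collapsed
// type stays static and contiguous.
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x5xf32> into memref<20xf32>

// Padded rows (stride 8 != 1 * 5): the band is not reshapable in place;
// the collapsed dimension becomes dynamic and a copy may be required.
%1 = memref.collapse_shape %arg1 [[0, 1]]
    : memref<4x5xf32, offset: 0, strides: [8, 1]>
    into memref<?xf32, offset: 0, strides: [?]>
```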
+
+void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<NamedAttribute> attrs) {
+ auto memRefType = src.getType().cast<MemRefType>();
+ auto resultType = computeReshapeCollapsedType(
+ memRefType, getSymbolLessAffineMaps(
+ convertReassociationIndicesToExprs(b, reassociation)));
+ build(b, result, resultType, src, attrs);
+ result.addAttribute(getReassociationAttrName(),
+ getReassociationIndicesAttribute(b, reassociation));
+}
+
+void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<NamedAttribute> attrs) {
+ auto memRefType = src.getType().cast<MemRefType>();
+ auto resultType = computeReshapeCollapsedType(
+ memRefType, getSymbolLessAffineMaps(
+ convertReassociationIndicesToExprs(b, reassociation)));
+ build(b, result, resultType, src, attrs);
+ result.addAttribute(getReassociationAttrName(),
+ getReassociationIndicesAttribute(b, reassociation));
+}
+
+template <typename ReshapeOp,
+ bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
+static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
+ MemRefType collapsedType) {
+ if (failed(
+ verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
+ return failure();
+ auto maps = op.getReassociationMaps();
+ MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
+ if (collapsedType != expectedType)
+ return op.emitOpError("expected collapsed type to be ")
+ << expectedType << ", but got " << collapsedType;
+ return success();
+}
+
+static LogicalResult verify(ExpandShapeOp op) {
+ return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
+}
+
+void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<CollapseReshapeOps<ExpandShapeOp>,
+ CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
+}
+
+static LogicalResult verify(CollapseShapeOp op) {
+ return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
+}
+
+void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<CollapseReshapeOps<CollapseShapeOp>,
+ CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
+}
+OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+}
+OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+}
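Together with the `CollapseReshapeOps`/`CollapseMixedReshapeOps` patterns, the folders cancel round-trip reshapes; a sketch mirroring the `fold_memref_reshape` test that moves out of the Linalg tests below:

```mlir
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
    : memref<12x4xf32> into memref<3x4x4xf32>
%1 = memref.collapse_shape %0 [[0, 1], [2]]
    : memref<3x4x4xf32> into memref<12x4xf32>
// After folding, %1 is replaced by %arg0 and both reshapes go away.
```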
+
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 4cd72e2c9ff3..0c7aa52848b3 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -15,8 +15,6 @@
using namespace mlir;
-constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
-
Optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
ShapedType targetType) {
@@ -183,6 +181,70 @@ Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
return composedIndices;
}
+SmallVector<SmallVector<AffineExpr, 2>, 2>
+mlir::convertReassociationIndicesToExprs(
+ OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices) {
+ SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
+ for (const auto &indices : reassociationIndices) {
+ SmallVector<AffineExpr, 2> reassociationMap;
+ reassociationMap.reserve(indices.size());
+ for (int64_t index : indices)
+ reassociationMap.push_back(b.getAffineDimExpr(index));
+ reassociationMaps.push_back(std::move(reassociationMap));
+ }
+ return reassociationMaps;
+}
+
+template <typename AffineExprTy>
+unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
+ unsigned pos = 0;
+ for (const auto &exprs : exprArrays) {
+ for (auto expr : exprs) {
+ expr.walk([&pos](AffineExpr e) {
+ if (auto d = e.dyn_cast<AffineExprTy>())
+ pos = std::max(pos, d.getPosition());
+ });
+ }
+ }
+ return pos;
+}
+
+ArrayAttr mlir::getReassociationIndicesAttribute(
+ OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
+ SmallVector<Attribute, 4> reassociationAttr =
+ llvm::to_vector<4>(llvm::map_range(
+ reassociation, [&](ReassociationIndices indices) -> Attribute {
+ return b.getI64ArrayAttr(indices).cast<Attribute>();
+ }));
+ return b.getArrayAttr(reassociationAttr);
+}
+
+SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
+ OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
+ SmallVector<ReassociationIndices, 2> reassociationIndices;
+ for (const auto &exprs : reassociationExprs) {
+ ReassociationIndices indices;
+ indices.reserve(exprs.size());
+ for (const auto &expr : exprs)
+ indices.push_back(expr.cast<AffineDimExpr>().getPosition());
+ reassociationIndices.push_back(indices);
+ }
+ return reassociationIndices;
+}
+
+SmallVector<AffineMap, 4>
+mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
+ unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
+ assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
+ "Expected symbol-less expressions");
+ SmallVector<AffineMap, 4> maps;
+ maps.reserve(reassociation.size());
+ for (const auto &exprs : reassociation) {
+ assert(!exprs.empty());
+ maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
+ }
+ return maps;
+}
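For reference, a sketch of how the two encodings line up: the index form `[[0, 1], [2]]` on a rank-3 source corresponds to the symbol-less maps `(d0, d1, d2) -> (d0, d1)` and `(d0, d1, d2) -> (d2)`:

```mlir
%0 = memref.collapse_shape %arg0 [[0, 1], [2]]
    : memref<?x4x?xf32> into memref<?x?xf32>
```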
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
int *invalidIndex) {
if (reassociation.empty())
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 2c4d386a9b3f..aa372355f5f6 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -698,3 +698,105 @@ func @get_gv3_memref() {
return
}
+// -----
+
+func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+  // Reshapes that expand a contiguous memref with some unit dimensions.
+ %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+ : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+ return %0 : memref<1x3x4x1x5xf32>
+}
+// CHECK-LABEL: func @expand_shape_static
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(3 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(4 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(5 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(60 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(20 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(5 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(5 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+
+// -----
+
+func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
+ memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+ return %0 : memref<3x4x5xf32>
+}
+// CHECK-LABEL: func @collapse_shape_static
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(3 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(4 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(5 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(20 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(5 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
+
+// -----
+
+func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
+ %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
+ return %0 : memref<f32>
+}
+// CHECK-LABEL: func @collapse_shape_fold_zero_dim
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+
+// -----
+
+func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
+ %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
+ return %0 : memref<1x1xf32>
+}
+// CHECK-LABEL: func @expand_shape_zero_dim
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 98ab9ab06796..fa95deb1dbaa 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -261,7 +261,7 @@ func @bufferize_tensor_collapse_shape(%arg0: tensor<4x5xf32>) -> tensor<20xf32>
return %out : tensor<20xf32>
}
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x5xf32>
-// CHECK: %[[RESHAPE:.*]] = linalg.collapse_shape %[[MEMREF]] {{\[}}[0, 1]]
+// CHECK: %[[RESHAPE:.*]] = memref.collapse_shape %[[MEMREF]] {{\[}}[0, 1]]
// CHECK-SAME: : memref<4x5xf32> into memref<20xf32>
// CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32>
// CHECK: return %[[TENSOR]]
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 03e19909fa28..c453255d3948 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -81,19 +81,6 @@ func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
// -----
-func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
- -> memref<f32> {
- %0 = linalg.collapse_shape %arg0 [[0, 1, 2]]
- : memref<1x1x1xf32> into memref<1xf32>
- %1 = linalg.collapse_shape %0 [] : memref<1xf32> into memref<f32>
- return %1 : memref<f32>
-}
-// CHECK-LABEL: collapsing_memref_reshapes_to_zero
-// CHECK: linalg.collapse_shape %{{.*}} []
-// CHECK-SAME: memref<1x1x1xf32> into memref<f32>
-
-// -----
-
func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x6x4x?x5xf32>
{
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
@@ -108,34 +95,6 @@ func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x6x4x?x5xf32
// -----
-func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>) -> memref<?x?xf32>
-{
- %0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
- : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
- %1 = linalg.collapse_shape %0 [[0, 1], [2]]
- : memref<?x?x?xf32> into memref<?x?xf32>
- return %1 : memref<?x?xf32>
-}
-// CHECK-LABEL: collapsing_memref_reshapes
-// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
-// CHECK-NOT: linalg.collapse_shape
-
-// -----
-
-func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x6x4x5x?xf32>
-{
- %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
- : memref<?x?xf32> into memref<?x4x?xf32>
- %1 = linalg.expand_shape %0 [[0, 1], [2], [3, 4]]
- : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
- return %1 : memref<?x6x4x5x?xf32>
-}
-// CHECK-LABEL: expanding_memref_reshapes
-// CHECK: linalg.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
-// CHECK-NOT: linalg.expand_shape
-
-// -----
-
func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
-> tensor<1x1x1xf32> {
%0 = linalg.tensor_expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
@@ -149,19 +108,6 @@ func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
// -----
-func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
- -> memref<1x1x1xf32> {
- %0 = linalg.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
- %1 = linalg.expand_shape %0 [[0, 1, 2]]
- : memref<1xf32> into memref<1x1x1xf32>
- return %1 : memref<1x1x1xf32>
-}
-// CHECK-LABEL: expanding_memref_reshapes_to_zero
-// CHECK: linalg.expand_shape %{{.*}} []
-// CHECK-SAME: memref<f32> into memref<1x1x1xf32>
-
-// -----
-
func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
{
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
@@ -188,32 +134,6 @@ func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
// -----
-func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>
-{
- %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
- : memref<12x4xf32> into memref<3x4x4xf32>
- %1 = linalg.collapse_shape %0 [[0, 1], [2]]
- : memref<3x4x4xf32> into memref<12x4xf32>
- return %1 : memref<12x4xf32>
-}
-// CHECK-LABEL: @fold_memref_reshape
-// CHECK-NOT: linalg.{{.*}}_shape
-
-// -----
-
-func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
-{
- %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
- : memref<?x?xf32> into memref<?x4x?xf32>
- %1 = linalg.collapse_shape %0 [[0, 1], [2]]
- : memref<?x4x?xf32> into memref<?x?xf32>
- return %1 : memref<?x?xf32>
-}
-// CHECK-LABEL: @fold_memref_reshape_dynamic
-// CHECK-NOT: linalg.{{.*}}_shape
-
-// -----
-
func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32>
{
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 69353eb7b744..a4357b6e4cd1 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -479,7 +479,7 @@ func @drop_one_trip_loops(%arg0 : memref<?x1x?xf32>, %arg1 : f32, %shape: memref
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @drop_one_trip_loops
-// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
@@ -556,7 +556,7 @@ func @drop_all_loops(%arg0 : memref<1x1xf32>) -> memref<1x1xf32>
}
// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @drop_all_loops
-// CHECK: linalg.collapse_shape %{{.*}} []
+// CHECK: memref.collapse_shape %{{.*}} []
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
// CHECK-SAME: iterator_types = []
@@ -617,7 +617,7 @@ func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref<5xf3
// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @leading_dim_1_canonicalization
-// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1]]
+// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]]]
// CHECK-SAME: iterator_types = ["parallel"]
@@ -638,8 +638,8 @@ func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref<5xf3
func @broadcast_test(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %shape : memref<5x5xf32>) -> memref<5x5xf32>
{
- %0 = linalg.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32>
- %1 = linalg.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32>
+ %0 = memref.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32>
+ %1 = memref.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32>
linalg.generic #trait
ins(%0, %1 : memref<1x5xf32>, memref<5x1xf32>)
outs(%shape : memref<5x5xf32>) {
@@ -686,7 +686,7 @@ func @broadcast_scalar(%arg0 : memref<1x1xf32>, %shape : memref<?x?xf32>) -> mem
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @broadcast_scalar
// CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32>
-// CHECK: %[[A:.*]] = linalg.collapse_shape %[[ARG0]] []
+// CHECK: %[[A:.*]] = memref.collapse_shape %[[ARG0]] []
// CHECK-SAME: memref<1x1xf32> into memref<f32>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
@@ -706,16 +706,16 @@ func @fold_unit_dim_memref_reshape_op(%arg0 : memref<5xf32>) -> memref<2x5xf32>
^bb0(%arg1: f32, %arg2: f32): // no predecessors
linalg.yield %arg1 : f32
}
- %3 = linalg.collapse_shape %1 [[0, 1], [2]]
+ %3 = memref.collapse_shape %1 [[0, 1], [2]]
: memref<1x2x5xf32> into memref<2x5xf32>
return %3 : memref<2x5xf32>
}
// CHECK-LABEL: func @fold_unit_dim_memref_reshape_op
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1x2x5xf32>
-// CHECK: %[[OUT:.*]] = linalg.collapse_shape %[[ALLOC]]
+// CHECK: %[[OUT:.*]] = memref.collapse_shape %[[ALLOC]]
// CHECK: linalg.generic
// CHECK-SAME: outs(%[[OUT:.*]] :
-// CHECK: %[[RESULT:.*]] = linalg.collapse_shape %[[ALLOC]]
+// CHECK: %[[RESULT:.*]] = memref.collapse_shape %[[ALLOC]]
// CHECK: return %[[RESULT]]
// -----
@@ -740,8 +740,8 @@ func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32>
// CHECK: func @fold_unit_dim_for_init_memref
// CHECK: %[[INIT:.+]] = memref.alloc() : memref<1xf32>
-// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.collapse_shape %{{.+}} {{\[}}[0, 1]] : memref<1x1000xf32> into memref<1000xf32>
-// CHECK: %[[INIT_RESHAPE:.+]] = linalg.collapse_shape %[[INIT]] [] : memref<1xf32> into memref<f32>
+// CHECK: %[[INPUT_RESHAPE:.+]] = memref.collapse_shape %{{.+}} {{\[}}[0, 1]] : memref<1x1000xf32> into memref<1000xf32>
+// CHECK: %[[INIT_RESHAPE:.+]] = memref.collapse_shape %[[INIT]] [] : memref<1xf32> into memref<f32>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["reduction"]
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 1de474708803..569b9a1b387d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -308,58 +308,6 @@ func @generic(%arg0: memref<?x?xi4>) {
// -----
-func @reshape(%arg0: memref<f32>) {
- // expected-error @+1 {{expected non-zero memref ranks}}
- %0 = linalg.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
-}
-
-// -----
-
-func @collapse_to_higher_rank(%arg0: memref<f32>) {
- // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
- %0 = linalg.collapse_shape %arg0 [[0]] : memref<f32> into memref<1xf32>
-}
-
-// -----
-
-func @expand_to_smaller_rank(%arg0: memref<1xf32>) {
- // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
- %0 = linalg.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
-}
-
-// -----
-
-func @reshape(%arg0: memref<?xf32>) {
- // expected-error @+1 {{expected to collapse or expand dims}}
- %0 = linalg.collapse_shape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
-}
-
-// -----
-
-func @reshape(%arg0: memref<?x?x?xf32>) {
- // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
- %0 = linalg.collapse_shape %arg0 [[0, 1]] :
- memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
-}
-
-// -----
-
-func @reshape(%arg0: memref<?x?x?xf32>) {
- // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}}
- %0 = linalg.collapse_shape %arg0 [[0, 1], [1, 2]] :
- memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
-}
-
-// -----
-
-func @reshape(%arg0: memref<?x?x?xf32>) {
- // expected-error @+1 {{expected collapsed type to be 'memref<?x?xf32>', but got 'memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>'}}
- %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
- memref<?x?x?xf32> into memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
-}
-
-// -----
-
func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
%arg1: memref<2x3xf32>,
%arg2: memref<?x?x?xf32>) {
@@ -397,7 +345,6 @@ func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
return
}
-
// -----
func @init_tensor_err(%arg0 : index, %arg1 : index)
@@ -438,16 +385,6 @@ func @illegal_expanding_reshape_dynamic_tensor
// -----
-func @illegal_expanding_reshape_dynamic_memref
- (%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32>
-{
- // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
- %0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]]
- : memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
- return %0 : memref<?x?x?x4x?xf32>
-}
-
-// -----
func @illegal_expanding_reshape_static_tensor
(%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32>
@@ -471,28 +408,6 @@ func @illegal_collapsing_reshape_static_tensor
// -----
-func @illegal_expanding_reshape_static_memref
- (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32>
-{
- // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
- %0 = linalg.expand_shape %arg0 [[0], [1], [2, 3, 4]]
- : memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
- return %0 : memref<2x3x2x4x5xf32>
-}
-
-// -----
-
-func @illegal_collapsing_reshape_static_memref
- (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32>
-{
- // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
- %0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]]
- : memref<2x3x2x4x5xf32> into memref<2x3x20xf32>
- return %0 : memref<2x3x20xf32>
-}
-
-// -----
-
func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
{
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
@@ -533,46 +448,6 @@ func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>) -> te
// -----
-func @illegal_expanding_reshape_mixed_memref(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
-{
- // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
- %0 = linalg.expand_shape %arg0 [[0, 1], [2]]
- : memref<?x?xf32> into memref<?x4x5xf32>
- return %0 : memref<?x4x5xf32>
-}
-
-// -----
-
-func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
-{
- // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
- %0 = linalg.expand_shape %arg0 [[0], [1, 2]]
- : memref<?x?xf32> into memref<?x4x5xf32>
- return %0 : memref<?x4x5xf32>
-}
-
-// -----
-
-func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
-{
- // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
- %0 = linalg.collapse_shape %arg0 [[0, 1], [2]]
- : memref<?x4x5xf32> into memref<?x?xf32>
- return %0 : memref<?x?xf32>
-}
-
-// -----
-
-func @illegal_collapse_reshape_mixed_memref_2(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
-{
- // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
- %0 = linalg.collapse_shape %arg0 [[0], [1, 2]]
- : memref<?x4x5xf32> into memref<?x?xf32>
- return %0 : memref<?x?xf32>
-}
-
-// -----
-
func @pad_result_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32) -> tensor<?x?x?x8xf32> {
// expected-error @+1 {{specified type 'tensor<?x?x?x8xf32>' does not match the inferred type 'tensor<?x?x?x9xi32>}}
%0 = linalg.pad_tensor %arg0 low[1, %arg1, 2, 2] high[1, 2, %arg1, 3] {
@@ -824,6 +699,6 @@ func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
linalg.generic #attrs ins(%A: memref<5xf32>) outs(%B: memref<5xf32>) {
^bb0(%a: f32, %b: f32):
linalg.yield %a : f32
- }
+ }
return
}
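
A note on the test harness for the invalid.mlir hunks above and below: each
case is separated by "// -----" so mlir-opt can split the file, and each
"expected-error @+1 {{...}}" annotation asserts that the verifier emits the
quoted diagnostic on the following line. A minimal sketch of one migrated
case, assuming the file's standard RUN line (which is not shown in this
hunk):

  // RUN: mlir-opt -split-input-file -verify-diagnostics %s

  func @collapse_shape(%arg0: memref<?xf32>) {
    // expected-error @+1 {{expected to collapse or expand dims}}
    %0 = memref.collapse_shape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
  }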
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 4ddee115c0b0..6c60c5f3af3a 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -13,98 +13,3 @@ func @range(%arg0: index) {
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(i64, i64, i64)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(i64, i64, i64)>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(i64, i64, i64)>
-
-func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
- // Reshapes that expand a contiguous tensor with some 1's.
- %0 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]]
- : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
- return %0 : memref<1x3x4x1x5xf32>
-}
-// CHECK-LABEL: func @expand_shape_static
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(3 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(4 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(60 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(20 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-
-func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
- %0 = linalg.collapse_shape %arg0 [[0, 1], [2], [3, 4]] :
- memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
- return %0 : memref<3x4x5xf32>
-}
-// CHECK-LABEL: func @collapse_shape_static
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(3 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(4 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(20 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(5 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
-
-func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
- %0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
- return %0 : memref<f32>
-}
-// CHECK-LABEL: func @collapse_shape_fold_zero_dim
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-
-func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
- %0 = linalg.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
- return %0 : memref<1x1xf32>
-}
-// CHECK-LABEL: func @expand_shape_zero_dim
-// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
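
A note on the descriptor these deleted CHECK lines walk: the LLVM lowering
models a ranked memref as
!llvm.struct<(ptr<f32>, ptr<f32>, i64, array<N x i64>, array<N x i64>)>,
i.e. allocated pointer, aligned pointer, offset, then the N sizes and N
strides. Expanding or collapsing a contiguous, statically shaped memref only
rewrites this metadata, which is why the lowering is pure
extractvalue/insertvalue with no data movement. For the 0-d case the two
arrays disappear entirely; a sketch of the two descriptor types exercised by
the tests above:

  !llvm.struct<(ptr<f32>, ptr<f32>, i64)>                                  // memref<f32>
  !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>  // memref<1x1xf32>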
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 5d842bba5960..e0d7ab2dfb24 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -12,9 +12,7 @@
// CHECK-DAG: #[[$permute_1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
// CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[$strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
-// CHECK-DAG: #[[$strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)>
// CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
// CHECK-DAG: #[[$strided6D:.*]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)>
@@ -169,7 +167,6 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// -----
-
func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
linalg.fill(%arg1, %arg0) : f32, memref<?xf32, offset: ?, strides: [1]>
return
@@ -541,96 +538,6 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
// -----
-func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>,
- %arg2: tensor<3x?x5xf32>) {
- // Reshapes that collapse and expand back a contiguous buffer.
- %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
- memref<3x4x5xf32> into memref<12x5xf32>
- %r0 = linalg.expand_shape %0 [[0, 1], [2]] :
- memref<12x5xf32> into memref<3x4x5xf32>
- %1 = linalg.collapse_shape %arg0 [[0], [1, 2]] :
- memref<3x4x5xf32> into memref<3x20xf32>
- %r1 = linalg.expand_shape %1 [[0], [1, 2]] :
- memref<3x20xf32> into memref<3x4x5xf32>
- %2 = linalg.collapse_shape %arg0 [[0, 1, 2]] :
- memref<3x4x5xf32> into memref<60xf32>
- %r2 = linalg.expand_shape %2 [[0, 1, 2]] :
- memref<60xf32> into memref<3x4x5xf32>
- // Reshapes that expand and collapse back a contiguous buffer with some 1's.
- %3 = linalg.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
- memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
- %r3 = linalg.collapse_shape %3 [[0, 1], [2], [3, 4]] :
- memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
- // Reshapes on tensors.
- %t0 = linalg.tensor_expand_shape %arg1 [[0, 1], [2], [3, 4]] :
- tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
- %rt0 = linalg.tensor_collapse_shape %t0 [[0, 1], [2], [3, 4]] :
- tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
- %t1 = linalg.tensor_expand_shape %arg2 [[0, 1], [2], [3, 4]] :
- tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
- %rt1 = linalg.tensor_collapse_shape %t1 [[0], [1, 2], [3, 4]] :
- tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
- return
-}
-// CHECK-LABEL: func @reshape_static
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32>
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0], [1, 2]]
-// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32>
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1, 2]]
-// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
-// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
-//
-// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
-// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
-// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
-// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
-
-// -----
-
-func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
- %arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
- %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
- %0 = linalg.collapse_shape %arg0 [[0, 1], [2]] :
- memref<?x?x?xf32> into memref<?x?xf32>
- %r0 = linalg.expand_shape %0 [[0, 1], [2]] :
- memref<?x?xf32> into memref<?x4x?xf32>
- %1 = linalg.collapse_shape %arg1 [[0, 1], [2]] :
- memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
- memref<?x?xf32, offset : 0, strides : [?, 1]>
- %r1 = linalg.expand_shape %1 [[0, 1], [2]] :
- memref<?x?xf32, offset : 0, strides : [?, 1]> into
- memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
- %2 = linalg.collapse_shape %arg2 [[0, 1], [2]] :
- memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
- memref<?x?xf32, offset : ?, strides : [?, 1]>
- %r2 = linalg.expand_shape %2 [[0, 1], [2]] :
- memref<?x?xf32, offset : ?, strides : [?, 1]> into
- memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
- return
-}
-// CHECK-LABEL: func @reshape
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?xf32> into memref<?x4x?xf32>
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
-// CHECK: linalg.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
-// CHECK: linalg.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
%ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
@@ -670,17 +577,6 @@ func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (
// -----
-func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
-{
- %0 = linalg.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
- %1 = linalg.expand_shape %0 [] : memref<f32> into memref<1x1xf32>
- return %0, %1 : memref<f32>, memref<1x1xf32>
-}
-// CHECK-LABEL: func @memref_reshape_zero_dim
-// CHECK: linalg.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
-// CHECK: linalg.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
-
-// -----
func @init_tensor(%arg0 : index, %arg1 : index)
{
@@ -707,19 +603,6 @@ func @legal_collapsing_reshape_dynamic_tensor
// -----
-func @legal_collapsing_reshape_dynamic_memref
- (%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32>
-{
- %0 = linalg.collapse_shape %arg0 [[0], [1], [2, 3, 4]] :
- memref<?x?x?x4x?xf32> into memref<?x?x?xf32>
- return %0 : memref<?x?x?xf32>
-}
-// CHECK: func @legal_collapsing_reshape_dynamic_memref
-// CHECK: linalg.collapse_shape
-// CHECK-SAME: [0], [1], [2, 3, 4]
-
-// -----
-
func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
%0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
%1 = linalg.fill(%arg2, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 2ae2c06dea92..02a8ce4441c3 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -395,3 +395,81 @@ func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index) {
memref.store %0, %arg0[] : memref<memref<?xi32>>
return
}
+
+// -----
+
+func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
+ -> memref<f32> {
+ %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
+ : memref<1x1x1xf32> into memref<1xf32>
+ %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
+ return %1 : memref<f32>
+}
+// CHECK-LABEL: collapsing_memref_reshapes_to_zero
+// CHECK: memref.collapse_shape %{{.*}} []
+// CHECK-SAME: memref<1x1x1xf32> into memref<f32>
+
+// -----
+
+func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
+ -> memref<?x?xf32> {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
+ : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
+ %1 = memref.collapse_shape %0 [[0, 1], [2]]
+ : memref<?x?x?xf32> into memref<?x?xf32>
+ return %1 : memref<?x?xf32>
+}
+// CHECK-LABEL: collapsing_memref_reshapes
+// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+// CHECK-NOT: memref.collapse_shape
+
+// -----
+
+func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
+ -> memref<?x6x4x5x?xf32> {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+ : memref<?x?xf32> into memref<?x4x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]]
+ : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
+ return %1 : memref<?x6x4x5x?xf32>
+}
+// CHECK-LABEL: expanding_memref_reshapes
+// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+// CHECK-NOT: memref.expand_shape
+
+// -----
+
+func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
+ -> memref<1x1x1xf32> {
+ %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
+ %1 = memref.expand_shape %0 [[0, 1, 2]]
+ : memref<1xf32> into memref<1x1x1xf32>
+ return %1 : memref<1x1x1xf32>
+}
+// CHECK-LABEL: expanding_memref_reshapes_to_zero
+// CHECK: memref.expand_shape %{{.*}} []
+// CHECK-SAME: memref<f32> into memref<1x1x1xf32>
+
+// -----
+
+func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+ : memref<12x4xf32> into memref<3x4x4xf32>
+ %1 = memref.collapse_shape %0 [[0, 1], [2]]
+ : memref<3x4x4xf32> into memref<12x4xf32>
+ return %1 : memref<12x4xf32>
+}
+// CHECK-LABEL: @fold_memref_reshape
+// CHECK-NOT: memref.{{.*}}_shape
+
+// -----
+
+func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+ : memref<?x?xf32> into memref<?x4x?xf32>
+ %1 = memref.collapse_shape %0 [[0, 1], [2]]
+ : memref<?x4x?xf32> into memref<?x?xf32>
+ return %1 : memref<?x?xf32>
+}
+// CHECK-LABEL: @fold_memref_reshape_dynamic
+// CHECK-NOT: memref.{{.*}}_shape
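
The first two tests in this hunk exercise composition of consecutive
reshapes: applying [[0, 1], [2]] after [[0, 1], [2], [3, 4]] merges the
inner map's first two groups, so the pair canonicalizes to a single op with
the composed reassociation [[0, 1, 2], [3, 4]]. A sketch of the expected
canonicalized form (the inverse expand/collapse pairs in the last two tests
fold away entirely, hence the CHECK-NOTs):

  %0 = memref.collapse_shape %arg0 [[0, 1, 2], [3, 4]]
      : memref<?x?x?x?x?xf32> into memref<?x?xf32>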
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 63209ef10894..dcd1a6b12849 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -231,3 +231,125 @@ func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) {
memref.copy %arg0, %arg1 : memref<2xf32> to memref<2xf16>
return
}
+
+// -----
+
+func @expand_shape(%arg0: memref<f32>) {
+ // expected-error @+1 {{expected non-zero memref ranks}}
+ %0 = memref.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
+}
+
+// -----
+
+func @collapse_shape_to_higher_rank(%arg0: memref<f32>) {
+ // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
+ %0 = memref.collapse_shape %arg0 [[0]] : memref<f32> into memref<1xf32>
+}
+
+// -----
+
+func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
+ // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
+ %0 = memref.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
+}
+
+// -----
+
+func @collapse_shape(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{expected to collapse or expand dims}}
+ %0 = memref.collapse_shape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
+}
+
+// -----
+
+func @collapse_shape_mismatch_indices_num(%arg0: memref<?x?x?xf32>) {
+ // expected-error @+1 {{expected rank of the collapsed type(2) to be the number of reassociation maps(1)}}
+ %0 = memref.collapse_shape %arg0 [[0, 1]] :
+ memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
+}
+
+// -----
+
+func @collapse_shape_invalid_reassociation(%arg0: memref<?x?x?xf32>) {
+ // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}}
+ %0 = memref.collapse_shape %arg0 [[0, 1], [1, 2]] :
+ memref<?x?x?xf32> into memref<?x?xf32, offset: 0, strides: [?, 1]>
+}
+
+// -----
+
+func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
+ // expected-error @+1 {{expected collapsed type to be 'memref<?x?xf32>', but got 'memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>'}}
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
+ memref<?x?x?xf32> into memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
+}
+
+// -----
+
+func @expand_shape_illegal_dynamic_memref
+ (%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32> {
+ // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
+ %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
+ : memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
+ return %0 : memref<?x?x?x4x?xf32>
+}
+
+// -----
+
+func @expand_shape_illegal_static_memref
+ (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
+ // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+ %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
+ : memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
+ return %0 : memref<2x3x2x4x5xf32>
+}
+
+// -----
+
+func @collapse_shape_illegal_static_memref
+ (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32> {
+ // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+ %0 = memref.collapse_shape %arg0 [[0], [1], [2, 3, 4]]
+ : memref<2x3x2x4x5xf32> into memref<2x3x20xf32>
+ return %0 : memref<2x3x20xf32>
+}
+
+// -----
+
+func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>)
+ -> memref<?x4x5xf32> {
+ // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+ : memref<?x?xf32> into memref<?x4x5xf32>
+ return %0 : memref<?x4x5xf32>
+}
+
+// -----
+
+func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>)
+ -> memref<?x4x5xf32> {
+ // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+ %0 = memref.expand_shape %arg0 [[0], [1, 2]]
+ : memref<?x?xf32> into memref<?x4x5xf32>
+ return %0 : memref<?x4x5xf32>
+}
+
+// -----
+
+func @collapse_shape_illegal_mixed_memref(%arg0 : memref<?x4x5xf32>)
+ -> memref<?x?xf32> {
+ // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
+ : memref<?x4x5xf32> into memref<?x?xf32>
+ return %0 : memref<?x?xf32>
+}
+
+// -----
+
+func @collapse_shape_illegal_mixed_memref_2(%arg0 : memref<?x4x5xf32>)
+ -> memref<?x?xf32> {
+ // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+ %0 = memref.collapse_shape %arg0 [[0], [1, 2]]
+ : memref<?x4x5xf32> into memref<?x?xf32>
+ return %0 : memref<?x?xf32>
+}
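
The "static value of 40" diagnostics above follow from a simple product
rule: each collapsed dimension must equal the product of the source
dimensions in its reassociation group. With [[0], [1], [2, 3, 4]] on
memref<2x3x2x4x5xf32>, dimension 2 of the collapsed type must be
2 * 4 * 5 = 40, so memref<2x3x20xf32> is rejected. The legal counterpart
(illustrative, not part of this test file) would be:

  %0 = memref.collapse_shape %arg0 [[0], [1], [2, 3, 4]]
      : memref<2x3x2x4x5xf32> into memref<2x3x40xf32>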
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 993a6131ab51..714b769099a1 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -1,6 +1,11 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
+// CHECK-DAG: #[[$strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
+// CHECK-DAG: #[[$strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)>
+
// CHECK-LABEL: test_buffer_cast
func @test_buffer_cast(%arg0: tensor<?xi64>, %arg1: tensor<*xi64>) -> (memref<?xi64, affine_map<(d0) -> (d0 + 7)>>, memref<*xi64, 1>) {
%0 = memref.buffer_cast %arg0 : memref<?xi64, affine_map<(d0) -> (d0 + 7)>>
@@ -95,3 +100,114 @@ func @memref_alloca_scope() {
}
return
}
+
+func @expand_collapse_shape_static(%arg0: memref<3x4x5xf32>,
+ %arg1: tensor<3x4x5xf32>,
+ %arg2: tensor<3x?x5xf32>) {
+ // Reshapes that collapse and expand back a contiguous buffer.
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
+ memref<3x4x5xf32> into memref<12x5xf32>
+ %r0 = memref.expand_shape %0 [[0, 1], [2]] :
+ memref<12x5xf32> into memref<3x4x5xf32>
+ %1 = memref.collapse_shape %arg0 [[0], [1, 2]] :
+ memref<3x4x5xf32> into memref<3x20xf32>
+ %r1 = memref.expand_shape %1 [[0], [1, 2]] :
+ memref<3x20xf32> into memref<3x4x5xf32>
+ %2 = memref.collapse_shape %arg0 [[0, 1, 2]] :
+ memref<3x4x5xf32> into memref<60xf32>
+ %r2 = memref.expand_shape %2 [[0, 1, 2]] :
+ memref<60xf32> into memref<3x4x5xf32>
+ // Reshapes that expand and collapse back a contiguous buffer with some 1's.
+ %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
+ memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+ %r3 = memref.collapse_shape %3 [[0, 1], [2], [3, 4]] :
+ memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+ // Reshapes on tensors.
+ %t0 = linalg.tensor_expand_shape %arg1 [[0, 1], [2], [3, 4]] :
+ tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+ %rt0 = linalg.tensor_collapse_shape %t0 [[0, 1], [2], [3, 4]] :
+ tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+ %t1 = linalg.tensor_expand_shape %arg2 [[0, 1], [2], [3, 4]] :
+ tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+ %rt1 = linalg.tensor_collapse_shape %t1 [[0], [1, 2], [3, 4]] :
+ tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
+ return
+}
+// CHECK-LABEL: func @expand_collapse_shape_static
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32>
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32>
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]]
+// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+//
+// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+// CHECK: linalg.tensor_expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+// CHECK: linalg.tensor_collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
+
+
+func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
+ %arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
+ %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
+ memref<?x?x?xf32> into memref<?x?xf32>
+ %r0 = memref.expand_shape %0 [[0, 1], [2]] :
+ memref<?x?xf32> into memref<?x4x?xf32>
+ %1 = memref.collapse_shape %arg1 [[0, 1], [2]] :
+ memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
+ memref<?x?xf32, offset : 0, strides : [?, 1]>
+ %r1 = memref.expand_shape %1 [[0, 1], [2]] :
+ memref<?x?xf32, offset : 0, strides : [?, 1]> into
+ memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
+ %2 = memref.collapse_shape %arg2 [[0, 1], [2]] :
+ memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
+ memref<?x?xf32, offset : ?, strides : [?, 1]>
+ %r2 = memref.expand_shape %2 [[0, 1], [2]] :
+ memref<?x?xf32, offset : ?, strides : [?, 1]> into
+ memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
+ return
+}
+// CHECK-LABEL: func @expand_collapse_shape_dynamic
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?xf32> into memref<?x4x?xf32>
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
+
+func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>)
+ -> (memref<f32>, memref<1x1xf32>) {
+ %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
+ %1 = memref.expand_shape %0 [] : memref<f32> into memref<1x1xf32>
+ return %0, %1 : memref<f32>, memref<1x1xf32>
+}
+// CHECK-LABEL: func @expand_collapse_shape_zero_dim
+// CHECK: memref.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
+// CHECK: memref.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+
+func @collapse_shape_to_dynamic
+ (%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32> {
+ %0 = memref.collapse_shape %arg0 [[0], [1], [2, 3, 4]] :
+ memref<?x?x?x4x?xf32> into memref<?x?x?xf32>
+ return %0 : memref<?x?x?xf32>
+}
+// CHECK: func @collapse_shape_to_dynamic
+// CHECK: memref.collapse_shape
+// CHECK-SAME: [0], [1], [2, 3, 4]
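
The CHECK-DAG captures added at the top of ops.mlir bind the layout maps
that the offset/strides syntax in the dynamic test expands to; the
"offset : X, strides : [...]" form is standard MLIR shorthand for an affine
layout map (an equivalence that predates this commit):

  memref<?x?xf32, offset : 0, strides : [?, 1]>
    // == memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>           (#strided2DOFF0)
  memref<?x?xf32, offset : ?, strides : [?, 1]>
    // == memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>  (#strided2D)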
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index c9f831f07a7d..1fd0a2bd2100 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5862,6 +5862,7 @@ cc_library(
":LLVMDialect",
":LinalgOps",
":LinalgTransforms",
+ ":MemRefToLLVM",
":Pass",
":SCFDialect",
":SCFToStandard",