[Mlir-commits] [mlir] 81264df - [mlir][Linalg] Add utility method to reshape ops to express output shape in terms of input shape.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 16 13:42:42 PST 2021
Author: MaheshRavishankar
Date: 2021-02-16T13:42:08-08:00
New Revision: 81264dfbe80df08668a325a61613b64243b99c01
URL: https://github.com/llvm/llvm-project/commit/81264dfbe80df08668a325a61613b64243b99c01
DIFF: https://github.com/llvm/llvm-project/commit/81264dfbe80df08668a325a61613b64243b99c01.diff
LOG: [mlir][Linalg] Add utility method to reshape ops to express output shape in terms of input shape.
Resolving the dims of the outputs of a tensor_reshape op in terms of its
input shape allows the op to be eliminated when it's used only by `dim`
ops. The init_tensor -> tensor_reshape canonicalization can be
simplified to use the dims of the output of the tensor_reshape, which
get canonicalized away later, making the tensor_reshape dead.
Differential Revision: https://reviews.llvm.org/D96635
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index a98336382fe6..f22b00da01c9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -107,6 +107,13 @@ SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
void getDimsOfType(Operation *op, StringRef iteratorTypeName,
SmallVectorImpl<AffineExpr> &res);
+/// For a reshape operation, compute the shape of the output, as one `Value`
+/// per result dimension, from the static result shape and the shape of `src`.
+SmallVector<Value, 4>
+getReshapeOutputShapeFromInputShape(OpBuilder &b, Location loc, Value src,
+ ArrayRef<int64_t> dstStaticShape,
+ ArrayRef<AffineMap> reassociation);
+
namespace detail {
LogicalResult verifyStructuredOpInterface(Operation *op);
} // namespace detail
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 7212700d641e..dc99e217aeb4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -342,10 +342,15 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
SmallVector<ReassociationExprs, 4> getReassociationExprs() {
return
llvm::to_vector<4>(llvm::map_range(reassociation(),
- [](Attribute a) {
- return llvm::to_vector<2>(
- a.cast<AffineMapAttr>().getValue().getResults());
- }));
+ [](Attribute a) {
+ return llvm::to_vector<2>(
+ a.cast<AffineMapAttr>().getValue().getResults());
+ }));
+ }
+ SmallVector<Value, 4> getOutputShape(OpBuilder &b, Location loc) {
+ return getReshapeOutputShapeFromInputShape(
+ b, loc, src(), getResultType().getShape(),
+ getReassociationMaps());
}
}];
let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0efcddfbe1c6..7c348672dc37 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -605,85 +605,6 @@ Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
return RankedTensorType::get(staticSizes, elementType);
}
-namespace {
-/// Change the type of the result of a `linalg.init_tensor` by making the result
-/// type statically sized along dimension that in the original operation where
-/// defined as dynamic, but the size was defined using a `constant` op. For
-/// example
-///
-/// %c5 = constant 5: index
-/// %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
-///
-/// to
-///
-/// %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
-struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
- using OpRewritePattern<InitTensorOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(InitTensorOp op,
- PatternRewriter &rewriter) const override {
- SmallVector<Value, 4> dynamicSizes;
- SmallVector<int64_t, 4> staticSizes;
- for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
- // If the size is already static, nothing to do.
- if (!op.isDynamicSize(i)) {
- staticSizes.push_back(op.getStaticSize(i));
- continue;
- }
-
- // If the size is dynamic but defined using a `constant` op, get the
- // constant value to find the static size to use.
- unsigned operandNum = op.getIndexOfDynamicSize(i);
- Value sizeOperand = op.getOperand(operandNum);
- if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) {
- staticSizes.push_back(constantIndexOp.getValue());
- continue;
- }
-
- // Fallback case. Keep the size dynamic.
- dynamicSizes.push_back(sizeOperand);
- staticSizes.push_back(ShapedType::kDynamicSize);
- }
- RankedTensorType newType =
- RankedTensorType::get(staticSizes, op.getType().getElementType());
- if (newType == op.getType())
- return failure();
- auto newOp =
- rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
- rewriter.getI64ArrayAttr(staticSizes));
- rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
- return success();
- }
-};
-
-/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
-/// with
-/// - A constant value if the size is static along the dimension.
-/// - The dynamic value that defines the size of the result of
-/// `linalg.init_tensor` op.
-struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp dimOp,
- PatternRewriter &rewriter) const override {
- auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
- if (!initTensorOp)
- return failure();
- auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
- if (!dimIndex)
- return failure();
- int64_t index = dimIndex.getValue();
- if (!initTensorOp.isDynamicSize(index)) {
- rewriter.replaceOpWithNewOp<ConstantIndexOp>(
- dimOp, initTensorOp.getStaticSize(index));
- } else {
- rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
- }
- return success();
- }
-};
-} // namespace
-
static Value getCollapsedInitTensor(OpBuilder &builder,
TensorReshapeOp reshapeOp) {
Location loc = reshapeOp.getLoc();
@@ -773,6 +694,85 @@ static Value getExpandedInitTensor(OpBuilder &builder,
srcType.getElementType());
}
+namespace {
+/// Change the type of the result of a `linalg.init_tensor` by making the result
+/// type statically sized along the dimensions that the original operation
+/// declared dynamic but whose size is defined using a `constant` op. For
+/// example
+///
+/// %c5 = constant 5: index
+/// %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
+///
+/// to (followed by a `tensor.cast` back to the original type)
+///
+/// %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
+struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
+ using OpRewritePattern<InitTensorOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InitTensorOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value, 4> dynamicSizes;
+ SmallVector<int64_t, 4> staticSizes;
+ for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
+ // If the size is already static, nothing to do.
+ if (!op.isDynamicSize(i)) {
+ staticSizes.push_back(op.getStaticSize(i));
+ continue;
+ }
+
+ // If the size is dynamic but defined using a `constant` op, get the
+ // constant value to find the static size to use.
+ unsigned operandNum = op.getIndexOfDynamicSize(i);
+ Value sizeOperand = op.getOperand(operandNum);
+ if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) {
+ staticSizes.push_back(constantIndexOp.getValue());
+ continue;
+ }
+
+ // Fallback case. Keep the size dynamic.
+ dynamicSizes.push_back(sizeOperand);
+ staticSizes.push_back(ShapedType::kDynamicSize);
+ }
+ RankedTensorType newType =
+ RankedTensorType::get(staticSizes, op.getType().getElementType());
+ if (newType == op.getType()) // No dynamic size was made static.
+ return failure();
+ auto newOp =
+ rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
+ rewriter.getI64ArrayAttr(staticSizes));
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+ return success();
+ }
+};
+
+/// Canonicalize a `linalg.init_tensor` -> `dim` pattern, where the `dim` index
+/// is defined by a `constant` op, by replacing the `dim` with
+/// - A constant value if the size is static along the dimension.
+/// - The dynamic value that defines the size of the result of
+/// `linalg.init_tensor` op.
+struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dimOp,
+ PatternRewriter &rewriter) const override {
+ auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
+ if (!initTensorOp)
+ return failure();
+ auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
+ if (!dimIndex)
+ return failure();
+ int64_t index = dimIndex.getValue();
+ if (!initTensorOp.isDynamicSize(index)) {
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(
+ dimOp, initTensorOp.getStaticSize(index));
+ } else {
+ rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
+ }
+ return success();
+ }
+};
+} // namespace
+
namespace {
/// Since `init_tensor` operation creates a tensor needed only for its shape, a
/// subtensor of this is also needed only for its shape. The result can be
@@ -803,17 +803,13 @@ struct FoldInitTensorWithTensorReshapeOp
PatternRewriter &rewriter) const override {
if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
return failure();
- RankedTensorType collapsedType = reshapeOp.getSrcType();
- RankedTensorType expandedType = reshapeOp.getResultType();
- bool isCollapsed = expandedType.getRank() < collapsedType.getRank();
- if (isCollapsed)
- std::swap(collapsedType, expandedType);
- Value initTensorOp = isCollapsed
- ? getCollapsedInitTensor(rewriter, reshapeOp)
- : getExpandedInitTensor(rewriter, reshapeOp);
- if (!initTensorOp)
- return failure();
- rewriter.replaceOp(reshapeOp, initTensorOp);
+ Location loc = reshapeOp.getLoc();
+ SmallVector<Value, 4> resultShapeValues =
+ reshapeOp.getOutputShape(rewriter, loc);
+ Value initTensor = rewriter.create<InitTensorOp>(
+ loc, resultShapeValues, reshapeOp.getResultType().getElementType());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ reshapeOp, reshapeOp.getResultType(), initTensor);
return success();
}
};
@@ -1255,6 +1251,141 @@ convertReassociationIndicesToMaps(
return reassociationMaps;
}
+/// For a collapsing reshape op, compute the size of output dimension
+/// `dimIndex` in terms of the shape of `src`: the product of the sizes of the
+/// `src` dimensions spanned by `reassociationMap[dimIndex]`. A `dim` op is
+/// created for each of those dimensions, including statically sized ones.
+static Value
+getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
+ int64_t dimIndex, Value src,
+ ArrayRef<AffineMap> reassociationMap) {
+ AffineMap map = reassociationMap[dimIndex];
+ unsigned startPos =
+ map.getResults().front().cast<AffineDimExpr>().getPosition();
+ unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
+ AffineExpr expr;
+ SmallVector<Value, 2> dynamicDims;
+ for (auto dim : llvm::seq(startPos, endPos + 1)) {
+ dynamicDims.push_back(builder.create<DimOp>(loc, src, dim));
+ AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
+ expr = (expr ? expr * currExpr : currExpr); // Builds s0 * s1 * ... * sN.
+ }
+ return applyMapToValues(builder, loc,
+ AffineMap::get(0, endPos - startPos + 1, expr),
+ dynamicDims)[0];
+}
+
+/// Given the `src` of a collapsing reshape op and its reassociation maps,
+/// compute one `Value` per result dim; only `dstStaticShape.size()` is read.
+static SmallVector<Value, 4> getCollapsedOutputShapeFromInputShape(
+ OpBuilder &builder, Location loc, Value src,
+ ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+ return llvm::to_vector<4>(llvm::map_range(
+ llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
+ return getCollapsedOutputDimFromInputShape(builder, loc, dim, src,
+ reassociation);
+ }));
+}
+
+/// Compute a map that for a given dimension of the expanded type gives the
+/// dimension in the collapsed type it maps to. Essentially it's the inverse of
+/// the `reassociation` maps (every dim spanned by map `i` maps to `i`).
+static llvm::DenseMap<int64_t, int64_t>
+getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
+ llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
+ for (auto map : enumerate(reassociation)) {
+ unsigned startPos =
+ map.value().getResults().front().cast<AffineDimExpr>().getPosition();
+ unsigned endPos =
+ map.value().getResults().back().cast<AffineDimExpr>().getPosition();
+ for (auto dim : llvm::seq(startPos, endPos + 1)) {
+ expandedDimToCollapsedDim[dim] = map.index();
+ }
+ }
+ return expandedDimToCollapsedDim;
+}
+
+/// For an expanding reshape op, compute output dimension `dimIndex`: a constant
+/// if static, else the source size `floordiv` the product of the other sizes.
+static Value getExpandedOutputDimFromInputShape(
+ OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
+ ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
+ llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
+ if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
+ return builder.create<ConstantIndexOp>(loc, dstStaticShape[dimIndex]);
+ }
+ unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
+ unsigned startPos = reassociation[sourceDimPos]
+ .getResults()
+ .front()
+ .cast<AffineDimExpr>()
+ .getPosition();
+ unsigned endPos = reassociation[sourceDimPos]
+ .getResults()
+ .back()
+ .cast<AffineDimExpr>()
+ .getPosition();
+ int64_t linearizedStaticDim = 1;
+ for (auto d :
+ llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
+ if (d.index() + startPos == static_cast<unsigned>(dimIndex))
+ continue;
+ assert(!ShapedType::isDynamic(d.value()) &&
+ "single dimension cannot be expanded into multiple dynamic "
+ "dimensions");
+ linearizedStaticDim *= d.value();
+ }
+ Value sourceDim = builder.create<DimOp>(loc, src, sourceDimPos);
+ return applyMapToValues(
+ builder, loc,
+ AffineMap::get(
+ 0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
+ sourceDim)[0];
+}
+
+/// Given the `src` of an expanding reshape op, the reassociation maps and the
+/// static result shape, compute one `Value` per dimension of the result.
+static SmallVector<Value, 4> getExpandedOutputShapeFromInputShape(
+ OpBuilder &builder, Location loc, Value src,
+ ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+ llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
+ getExpandedDimToCollapsedDimMap(reassociation);
+ return llvm::to_vector<4>(llvm::map_range(
+ llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
+ return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
+ dstStaticShape, reassociation,
+ expandedDimToCollapsedDim);
+ }));
+}
+
+SmallVector<Value, 4> mlir::linalg::getReshapeOutputShapeFromInputShape(
+ OpBuilder &builder, Location loc, Value src,
+ ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassocation) {
+ return dstStaticShape.size() >
+ static_cast<size_t>(src.getType().cast<ShapedType>().getRank())
+ ? getExpandedOutputShapeFromInputShape(
+ builder, loc, src, dstStaticShape, reassocation)
+ : getCollapsedOutputShapeFromInputShape(
+ builder, loc, src, dstStaticShape, reassocation);
+}
+
+/// For a reshape op, compute the value of output dimension `dimIndex` from the
+/// shape of `src`: expanding if the result rank is larger, else collapsing.
+static Value getReshapeOutputDimFromInputShape(
+ OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
+ ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+ if (dstStaticShape.size() >
+ static_cast<size_t>(src.getType().cast<ShapedType>().getRank())) {
+ llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
+ getExpandedDimToCollapsedDimMap(reassociation);
+ return getExpandedOutputDimFromInputShape(builder, loc, dimIndex, src,
+ dstStaticShape, reassociation,
+ expandedDimToCollapsedDim);
+ }
+ return getCollapsedOutputDimFromInputShape(builder, loc, dimIndex, src,
+ reassociation);
+}
+
void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
Value src,
ArrayRef<ReassociationExprs> reassociation,
@@ -1478,12 +1609,35 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
return success();
}
};
+
+/// Rewrite a constant-index `dim` of a tensor_reshape result via `src`'s shape.
+struct ReplaceDimOfReshapeOpResult : OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(DimOp dimOp,
+ PatternRewriter &rewriter) const override {
+ Value dimValue = dimOp.memrefOrTensor();
+ Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+ if (!dimIndex)
+ return failure();
+
+ auto reshapeOp = dimValue.getDefiningOp<TensorReshapeOp>();
+ if (!reshapeOp)
+ return failure();
+
+ rewriter.replaceOp(dimOp,
+ getReshapeOutputDimFromInputShape(
+ rewriter, dimOp.getLoc(), *dimIndex, reshapeOp.src(),
+ reshapeOp.getResultType().getShape(),
+ reshapeOp.getReassociationMaps()));
+ return success();
+ }
+};
} // namespace
void TensorReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
- context);
+ results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant,
+ ReplaceDimOfReshapeOpResult>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 75abef70cd4e..2fb5eb3086e6 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -560,10 +560,10 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
return %1 : tensor<2x3x5x4x?x7xf32>
}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @init_tensor_reshape_expansion
// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C28:.+]] = constant 28 : index
-// CHECK: %[[T0:.+]] = divi_unsigned %[[ARG0]], %[[C28]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
// CHECK: return %[[T1]]
@@ -578,10 +578,10 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
return %1 : tensor<6x5x?xf32>
}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @init_tensor_reshape_collapse
// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C28:.+]] = constant 28 : index
-// CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
// CHECK: return %[[T1]]
@@ -716,3 +716,54 @@ func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>,
} : tensor<?x?xf32> to tensor<2x4xf32>
return
}
+
+// -----
+
+func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
+{
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+ tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+ %1 = dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
+ %2 = dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
+ %3 = dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
+ return %1, %2, %3 : index, index, index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+// CHECK: func @dim_reshape_expansion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[D0:.+]] = dim %[[ARG0]], %[[C2]]
+// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+
+// -----
+
+func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
+{
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+ tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+ %1 = dim %0, %c1 : tensor<6x5x?xf32>
+ %2 = dim %0, %c2 : tensor<6x5x?xf32>
+ return %1, %2 : index, index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
+// CHECK: func @dim_reshape_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK: %[[D0:.+]] = dim %[[ARG0]], %[[C4]]
+// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: return %[[C5]], %[[D1]]
More information about the Mlir-commits
mailing list