[Mlir-commits] [mlir] 892fdc9 - [mlir][Linalg] Generalize the logic to compute reassociation maps
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 30 07:58:59 PDT 2020
Author: Mahesh Ravishankar
Date: 2020-09-30T07:58:06-07:00
New Revision: 892fdc923f06adbef507ebe594fa7b48224d93f0
URL: https://github.com/llvm/llvm-project/commit/892fdc923f06adbef507ebe594fa7b48224d93f0
DIFF: https://github.com/llvm/llvm-project/commit/892fdc923f06adbef507ebe594fa7b48224d93f0.diff
LOG: [mlir][Linalg] Generalize the logic to compute reassociation maps
while folding tensor_reshape op.
While folding reshapes that introduce unit-extent dims, the logic to
compute the reassociation maps can be generalized to handle some
corner cases — for example, when the folded shape still has unit-extent
dims that correspond to collapsed unit-extent dims of the expanded shape.
Differential Revision: https://reviews.llvm.org/D88521
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 08e7e352d63e9..611c938ab542f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -403,61 +403,58 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
srcType.getRank() < dstType.getRank() ||
parentSrcType.getRank() == dstType.getRank())
return failure();
+
// Check if the result tensor_reshape after folding the reshapeOp and
// parentReshapeOp are combined.
// If the final tensor_reshape is folding, the parentReshapeOp is
// introducing unit-dims, and the reshapeOp does an actual reshape.
- // If the final tensor_reshape op is expanding, the reshapeOp is introducing
- // unit-dims, and the parentReshapeOp does an actual reshape.
+ // If the final tensor_reshape op is expanding, the reshapeOp is
+ // introducing unit-dims, and the parentReshapeOp does an actual reshape.
bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
- auto reassociationMaps = isFoldingPattern
- ? reshapeOp.getReassociationMaps()
- : parentReshapeOp.getReassociationMaps();
- DenseSet<unsigned> conservedDimensions;
- for (auto &map : reassociationMaps) {
- if (map.getNumResults() == 1) {
- conservedDimensions.insert(
- map.getResult(0).cast<AffineDimExpr>().getPosition());
- }
- }
-
- // Find positions at which the unit-dims exist.
- int64_t nonUnitDimPos = 0;
- DenseMap<unsigned, unsigned> nonUnitSrcDims;
- ArrayRef<int64_t> nonUnitShape =
+ ArrayRef<int64_t> expandedShape =
isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
- for (auto shape : enumerate(srcType.getShape())) {
- // Case 1 : It is a conserved dimension.
- if (conservedDimensions.count(shape.index())) {
- nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
- continue;
+ ArrayRef<int64_t> foldedShape =
+ isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
+
+ unsigned expandedDim = 0, foldedDim = 0;
+ SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
+ foldedShape.size());
+ while (expandedDim < expandedShape.size() &&
+ foldedDim < foldedShape.size()) {
+ int64_t dstSize = foldedShape[foldedDim];
+ int64_t srcSize = expandedShape[expandedDim];
+ while (srcSize < dstSize && expandedDim < expandedShape.size()) {
+ reassociationExprs[foldedDim].push_back(
+ rewriter.getAffineDimExpr(expandedDim++));
+ srcSize *= expandedShape[expandedDim];
}
- // Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
- if (shape.value() == 1)
- continue;
- // Case 3 : Dimensions match, treat it as a non-unit src dim.
- if (nonUnitDimPos < static_cast<int64_t>(nonUnitShape.size()) &&
- nonUnitShape[nonUnitDimPos] == shape.value()) {
- nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
- continue;
+ if (srcSize == dstSize) {
+ reassociationExprs[foldedDim].push_back(
+ rewriter.getAffineDimExpr(expandedDim++));
+ // If the next dim in foldedShape is not 1, treat subsequent dims in
+ // expandedShape which are 1 to be collapsed.
+ if (foldedDim == foldedShape.size() - 1 ||
+ foldedShape[foldedDim + 1] != 1) {
+ while (expandedDim < expandedShape.size() &&
+ expandedShape[expandedDim] == 1) {
+ reassociationExprs[foldedDim].push_back(
+ rewriter.getAffineDimExpr(expandedDim++));
+ }
+ }
+ } else {
+ return failure();
}
- return failure();
+ foldedDim++;
}
+ if (expandedDim != expandedShape.size())
+ return failure();
- // Compute reassociation maps for the final operation. Use the reassociation
- // maps that is actually doing a reshape (and not just introducing
- // unit-dims). From these maps, prune the unit-extent dimensions.
- for (AffineMap &map : reassociationMaps) {
- SmallVector<AffineExpr, 4> exprs;
- exprs.reserve(nonUnitSrcDims.size());
- for (auto result : map.getResults()) {
- unsigned dim = result.cast<AffineDimExpr>().getPosition();
- if (nonUnitSrcDims.count(dim))
- exprs.push_back(rewriter.getAffineDimExpr(nonUnitSrcDims[dim]));
- }
- map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
- rewriter.getContext());
- }
+ SmallVector<AffineMap, 4> reassociationMaps =
+ llvm::to_vector<4>(llvm::map_range(
+ reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
+ return AffineMap::get(expandedShape.size(), 0, exprs,
+ rewriter.getContext());
+ }));
rewriter.replaceOpWithNewOp<TensorReshapeOp>(
reshapeOp, dstType, parentReshapeOp.src(),
rewriter.getAffineMapArrayAttr(reassociationMaps));
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 06e56c5cb7d2a..1793d2b59b706 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -240,3 +240,19 @@ func @fold_reshape(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
: tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
return %1 : tensor<4x512x1x512x4xf32>
}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fold_reshape
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
+// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
+func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
+{
+ %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : tensor<2xf32> into tensor<2x1x1xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2) -> (d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>
+ ] : tensor<2x1x1xf32> into tensor<2x1xf32>
+ return %1 : tensor<2x1xf32>
+}
More information about the Mlir-commits
mailing list