[Mlir-commits] [mlir] a43f7d6 - [mlir][tensor] Extend reshape utils.
Stephan Herhut
llvmlistbot at llvm.org
Fri Feb 18 00:58:09 PST 2022
Author: Stephan Herhut
Date: 2022-02-18T09:57:39+01:00
New Revision: a43f7d6d76984ddae4a5e5e0bebf82ee2edebabb
URL: https://github.com/llvm/llvm-project/commit/a43f7d6d76984ddae4a5e5e0bebf82ee2edebabb
DIFF: https://github.com/llvm/llvm-project/commit/a43f7d6d76984ddae4a5e5e0bebf82ee2edebabb.diff
LOG: [mlir][tensor] Extend reshape utils.
This change changes the handling of trailing dimensions with unknown
extent. Users of the getReassociationIndicesForReshape helper should
see benefits when transforming reshape like operations into
expand/collapse pairs if the higher-rank type has trailing unknown
dimensions.
The motivating example is a reshape from tensor<16x1x?xi32> to
tensor<16xi32> that can be modeled as collapsing the three dimensions.
Differential Revision: https://reviews.llvm.org/D119730
Added:
Modified:
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index fd509621015d2..17f449e489e38 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -35,13 +35,12 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
int64_t prodOfCollapsedDims = 1;
while (sourceDim < sourceShape.size()) {
unsigned targetDim = reassociationMap.size();
+ // If we have mapped all the target dimensions stop and handle the remaining
+ // tail of size-1 dimensions explictly.
+ if (targetDim == targetType.getRank())
+ break;
- // If all the dimensions of the targetShape are exhausted, then the
- // remaining dims in the source shape must be all 1s. So for such cases, set
- // 1 as the target shape. The actual reassociation indices will be handled
- // later.
- int64_t currTargetShape =
- (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
+ int64_t currTargetShape = targetShape[targetDim];
while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
sourceDim < sourceShape.size()) {
@@ -69,25 +68,23 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
return llvm::None;
currIndices.push_back(sourceDim++);
- // If the reassociation is empty but the currIndices is not, this by
- // definition is folding unit-dimensions with the result being scalar type.
- // So only append the `currIndices` if reassociation map is not empty.
- if (targetDim == targetShape.size()) {
- while (sourceDim < sourceShape.size())
- currIndices.push_back(sourceDim++);
- if (!reassociationMap.empty() && !currIndices.empty())
- reassociationMap.back().append(currIndices.begin(), currIndices.end());
- // Break out of the loops. We should be done here.
- break;
- }
reassociationMap.emplace_back(ReassociationIndices{});
std::swap(reassociationMap.back(), currIndices);
prodOfCollapsedDims = 1;
}
- // All the dimensions in the two shapes must have been processed.
- if (reassociationMap.size() != targetShape.size() ||
- sourceDim != sourceShape.size())
+ // All the dimensions in the target must have been processed.
+ if (reassociationMap.size() != targetShape.size())
return llvm::None;
+ // Process any remaining entries in the source shape. They all need to be
+ // 1 or dynamic.
+ for (; sourceDim < sourceShape.size(); sourceDim++) {
+ if (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
+ sourceShape[sourceDim] != 1)
+ return llvm::None;
+ // The map is empty when the target type is a scalar.
+ if (!reassociationMap.empty())
+ reassociationMap.back().push_back(sourceDim);
+ }
return reassociationMap;
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4e4bbb8a12672..ce3db8d6039c2 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -879,16 +879,17 @@ func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
// -----
-func @no_fold_reshapes(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
+func @fold_reshapes_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
%0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
: tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
: tensor<?x?x1x?xf32> into tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: func @no_fold_reshapes
-// CHECK: tensor.expand_shape
-// CHECK: tensor.collapse_shape
+// CHECK-LABEL: func @fold_reshapes_unit_dims_in_middle
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
+// CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
+// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?xf32>
// -----
More information about the Mlir-commits
mailing list