[Mlir-commits] [mlir] Fixes in 'tosa.reshape' lowering and folder (PR #85798)
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Mar 19 12:39:19 PDT 2024
================
@@ -19,217 +19,98 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <numeric>
+
using namespace mlir;
using namespace tosa;
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
-                                  ArrayRef<int64_t> rhsShape,
-                                  SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {ShapedType::kDynamic};
-    return true;
-  }
-
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
-
-  unsigned currLhsDim = 0, currRhsDim = 0;
-  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
-    int64_t rhsSize = rhsShape[currRhsDim];
-    int64_t lhsSize = lhsShape[currLhsDim];
-    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
-           currRhsDim < rhsShape.size()) {
-      if (lhsSize < rhsSize) {
-        currLhsDim++;
-        if (currLhsDim < lhsShape.size()) {
-          lhsSize *= lhsShape[currLhsDim];
-        }
-      } else {
-        currRhsDim++;
-        if (currRhsDim < rhsShape.size()) {
-          rhsSize *= rhsShape[currRhsDim];
-        }
-      }
-    }
-    if (lhsSize == rhsSize) {
-      intermediateShape.push_back(lhsSize);
-    }
-    currRhsDim++;
-    currLhsDim++;
-  }
-
-  // If the iterators didn't reach the end and their leftover dimensions are not
-  // equal to 1 an intermediate shape was not found.
-  while (currLhsDim < lhsShape.size()) {
-    if (lhsShape[currLhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  while (currRhsDim < rhsShape.size()) {
-    if (rhsShape[currRhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  return true;
+static Value getIndexConstant(OpBuilder &builder, Location loc, int64_t index) {
+  return builder.create<arith::ConstantIndexOp>(loc, index);
}
-static bool createReassociationMapsForCollapse(
-    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
-    ArrayRef<int64_t> dstShape,
-    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
-
-  // If the shape is dynamic, create a map for collapsing into one dimension.
-  if (isDynamic) {
-    SmallVector<AffineExpr, 2> exprs;
-    for (int i = 0, s = srcShape.size(); i < s; ++i)
-      exprs.push_back(rewriter.getAffineDimExpr(i));
-    reassociationMap = {exprs};
-    return true;
-  }
-
-  if (dstShape.empty()) {
-    reassociationMap = {};
-    return true;
-  }
-
-  reassociationMap.resize(dstShape.size());
-  unsigned currSrcDim = 0, currDstDim = 0;
-  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
-    int64_t dstSize = dstShape[currDstDim];
-    int64_t srcSize = srcShape[currSrcDim];
-    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      srcSize *= srcShape[currSrcDim];
-    }
-    if (srcSize == dstSize) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      // If the next dim in collapsedShape is not 1, treat subsequent dims in
-      // expandedShape which are 1 to be collapsed.
-      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
-        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
-          reassociationMap[currDstDim].push_back(
-              rewriter.getAffineDimExpr(currSrcDim++));
-        }
-      }
-    }
-    currDstDim++;
+// Return the total size of the given input tensor.
+static Value getTensorSize(OpBuilder &builder, Location loc,
+                           TypedValue<TensorType> input) {
+  // If the input tensor is statically shaped, return its size as a constant.
+  if (input.getType().hasStaticShape()) {
+    auto shape = input.getType().getShape();
+    auto size = std::accumulate(shape.begin(), shape.end(), int64_t(1),
+                                std::multiplies<int64_t>());
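+    // e.g., a tensor<2x3x4xf32> input yields the constant 2 * 3 * 4 = 24.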
+    return getIndexConstant(builder, loc, size);
  }
-  // If both iterators didn't reach the end, we have leftover dimentions which
-  // implies that we have a mismatch in shape.
-  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+  // When the input tensor has at least one dynamic dimension, collapse it into
+  // a 1D tensor and get its size.
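+  // e.g., a tensor<?x3xf32> input is collapsed into tensor<?xf32>, whose
+  // dimension 0 is then queried with a 'tensor.dim' op.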
+  auto rank = input.getType().getRank();
+  auto elementType = input.getType().getElementType();
+  auto collapsedType =
+      RankedTensorType::get({ShapedType::kDynamic}, elementType);
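+  // Place all input dimensions in a single reassociation group so the tensor
+  // collapses to rank 1.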
+  auto reassociationIndices = SmallVector<ReassociationIndices>{
+      llvm::to_vector(llvm::seq<int64_t>(rank))};
+  auto collapsed = builder.create<tensor::CollapseShapeOp>(
+      loc, collapsedType, input, reassociationIndices);
+  return builder.create<tensor::DimOp>(loc, collapsed, 0);
}
-namespace {
-Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
-                     ShapedType resultTy, Value operand) {
-  ShapedType operandTy = cast<ShapedType>(operand.getType());
-  if (resultTy == operandTy)
-    return operand;
-
-  bool isDynamic = !operandTy.hasStaticShape();
-
-  if (isDynamic && resultTy.getRank() != 1) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "Cannot collapse dynamic dims to more than one dimension");
-    return {};
-  }
-
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                          resultTy.getShape(),
-                                          reassociationMap, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
-    return {};
-  }
-
-  SmallVector<int64_t> intermediateShape;
-  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                             intermediateShape, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Cannot collapse into given shape");
-    return {};
-  }
-  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
-                                                  reassociationMap);
+// Compute the dimension size of the result tensor corresponding to the
+// placeholder value set to -1 in the 'new_shape' attribute of a 'tosa.reshape'
+// op. Argument 'index' indicates the position of the -1 placeholder.
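+// For example, with an input of total size 24 and new_shape = [2, -1, 4], the
+// placeholder at index 1 resolves to 24 / (2 * 4) = 3.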
+static Value getReshapePlaceholderDimSize(OpBuilder &builder,
+                                          tosa::ReshapeOp reshape,
+                                          int64_t index) {
+  auto loc = reshape.getLoc();
----------------
krzysz00 wrote:
Same nitpicking re `auto`s
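
LLVM's coding standard reserves `auto` for cases where the type is obvious
from the initializer, so the nit presumably asks for the types to be spelled
out. A hypothetical sketch of a few declarations from the hunk above with
explicit types, assuming that reading of the comment (`reshape` and `input`
are the corresponding function parameters from the hunk):

  // Hypothetical spelled-out equivalents of 'auto' declarations above.
  Location loc = reshape.getLoc();
  int64_t rank = input.getType().getRank();
  Type elementType = input.getType().getElementType();
  RankedTensorType collapsedType =
      RankedTensorType::get({ShapedType::kDynamic}, elementType);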
https://github.com/llvm/llvm-project/pull/85798