[Mlir-commits] [mlir] Fixes in 'tosa.reshape' lowering and folder (PR #85798)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Mar 19 12:39:18 PDT 2024

@@ -19,217 +19,98 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <numeric>
 using namespace mlir;
 using namespace tosa;
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
-                                  ArrayRef<int64_t> rhsShape,
-                                  SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {ShapedType::kDynamic};
-    return true;
-  }
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
-  unsigned currLhsDim = 0, currRhsDim = 0;
-  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
-    int64_t rhsSize = rhsShape[currRhsDim];
-    int64_t lhsSize = lhsShape[currLhsDim];
-    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
-           currRhsDim < rhsShape.size()) {
-      if (lhsSize < rhsSize) {
-        currLhsDim++;
-        if (currLhsDim < lhsShape.size()) {
-          lhsSize *= lhsShape[currLhsDim];
-        }
-      } else {
-        currRhsDim++;
-        if (currRhsDim < rhsShape.size()) {
-          rhsSize *= rhsShape[currRhsDim];
-        }
-      }
-    }
-    if (lhsSize == rhsSize) {
-      intermediateShape.push_back(lhsSize);
-    }
-    currRhsDim++;
-    currLhsDim++;
-  }
-  // If the iterators didn't reach the end and their leftover dimensions are not
-  // equal to 1 an intermediate shape was not found.
-  while (currLhsDim < lhsShape.size()) {
-    if (lhsShape[currLhsDim++] != 1) {
-      return false;
-    }
-  }
-  while (currRhsDim < rhsShape.size()) {
-    if (rhsShape[currRhsDim++] != 1) {
-      return false;
-    }
-  }
-  return true;
+static Value getIndexConstant(OpBuilder& builder, Location loc, int64_t index) {
+  return builder.create<arith::ConstantIndexOp>(loc, index);
-static bool createReassociationMapsForCollapse(
-    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
-    ArrayRef<int64_t> dstShape,
-    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
-  // If the shape is dynamic, create a map for collapsing into one dimension.
-  if (isDynamic) {
-    SmallVector<AffineExpr, 2> exprs;
-    for (int i = 0, s = srcShape.size(); i < s; ++i)
-      exprs.push_back(rewriter.getAffineDimExpr(i));
-    reassociationMap = {exprs};
-    return true;
-  }
-  if (dstShape.empty()) {
-    reassociationMap = {};
-    return true;
-  }
-  reassociationMap.resize(dstShape.size());
-  unsigned currSrcDim = 0, currDstDim = 0;
-  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
-    int64_t dstSize = dstShape[currDstDim];
-    int64_t srcSize = srcShape[currSrcDim];
-    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      srcSize *= srcShape[currSrcDim];
-    }
-    if (srcSize == dstSize) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      // If the next dim in collapsedShape is not 1, treat subsequent dims in
-      // expandedShape which are 1 to be collapsed.
-      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
-        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
-          reassociationMap[currDstDim].push_back(
-              rewriter.getAffineDimExpr(currSrcDim++));
-        }
-      }
-    }
-    currDstDim++;
+// Return the total size of the given input tensor.
+static Value getTensorSize(OpBuilder& builder, Location loc, TypedValue<TensorType> input) {
+  // If the input tensor is statically shaped, return its size as a constant.
+  if (input.getType().hasStaticShape()) {
+    auto shape = input.getType().getShape();
+    auto size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies());
+    return getIndexConstant(builder, loc, size);
-  // If both iterators didn't reach the end, we have leftover dimentions which
-  // implies that we have a mismatch in shape.
-  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+  // When the input tensor has at least one dynamic dimension, collapse it into
+  // a 1D tensor and get its size.
+  auto rank = input.getType().getRank();
+  auto elementType = input.getType().getElementType();
+  auto collapsedType = RankedTensorType::get({ShapedType::kDynamic}, elementType);
+  auto reassociationIndices = SmallVector<ReassociationIndices>{
+    llvm::to_vector(llvm::seq<int64_t>(rank))
+  };
+  auto collapsed = builder.create<tensor::CollapseShapeOp>(
+      loc, collapsedType, input, reassociationIndices);
+  return builder.create<tensor::DimOp>(loc, collapsed, 0);
-namespace {
-Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
-                     ShapedType resultTy, Value operand) {
-  ShapedType operandTy = cast<ShapedType>(operand.getType());
-  if (resultTy == operandTy)
-    return operand;
-  bool isDynamic = !operandTy.hasStaticShape();
-  if (isDynamic && resultTy.getRank() != 1) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "Cannot collapse dynamic dims to more than one dimension");
-    return {};
-  }
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                          resultTy.getShape(),
-                                          reassociationMap, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
-    return {};
-  }
-  SmallVector<int64_t> intermediateShape;
-  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                             intermediateShape, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Cannot collapse into given shape");
-    return {};
-  }
-  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
-                                                  reassociationMap);
+// Compute the dimension size of the result tensor corresponding to the
+// placeholder value set to -1 in the 'new_shape' attribute of a 'tosa.reshape'
+// op. Argument 'index' indicates the position of the -1 placeholder.
+static Value getReshapePlaceholderDimSize(OpBuilder &builder,
+                                          tosa::ReshapeOp reshape,
+                                          int64_t index) {
+  auto loc = reshape.getLoc();
+  auto input = reshape.getInput1();
+  auto newShape = reshape.getNewShape();
+  auto resultType = reshape.getResult().getType();
+  // If the corresponding dimension in the result type is static, take the
+  // dimension size from there.
+  assert(newShape[index] == -1);
+  if (!resultType.isDynamicDim(index))
+    return getIndexConstant(builder, loc, resultType.getDimSize(index));
+  // Calculate the product of all dimensions in the new shape. We expect to have
+  // exactly one size set to -1, so we can discard this component by just
+  // negating the final product.
+  auto newSizeLiteral = -std::accumulate(newShape.begin(), newShape.end(), 1,
+                                         std::multiplies<int64_t>());
+  assert(newSizeLiteral >= 0);
+  auto newSize = builder.create<arith::ConstantIndexOp>(loc, newSizeLiteral);
+  // Avoid a division by zero. If any of the given dimension sizes was set to
+  // zero, set the placeholder size to zero, too.
+  if (newSizeLiteral == 0)
+    return newSize;
+  // The size of the placeholder dimension is the size of the input tensor
+  // divided by all non-placeholder dimension sizes.
+  auto inputSize = getTensorSize(builder, loc, input);
+  return builder.createOrFold<arith::DivUIOp>(loc, inputSize, newSize);
krzysz00 wrote:

I figure unsigned division's justified here, but I'd like to confirm that "tensors can't have more than `INT64_MAX` elements" is written down somewhere.


More information about the Mlir-commits mailing list