[Mlir-commits] [mlir] ff23599 - [mlir][tosa] Update TOSA resize to match specification
Rob Suderman
llvmlistbot at llvm.org
Wed Oct 5 13:20:30 PDT 2022
Author: TatWai Chong
Date: 2022-10-05T13:18:00-07:00
New Revision: ff23599a0dfc857c4e80854a35026be4a29dd57c
URL: https://github.com/llvm/llvm-project/commit/ff23599a0dfc857c4e80854a35026be4a29dd57c
DIFF: https://github.com/llvm/llvm-project/commit/ff23599a0dfc857c4e80854a35026be4a29dd57c.diff
LOG: [mlir][tosa] Update TOSA resize to match specification
The stride and shift attributes are removed, and new scale and border attributes are added.
Signed-off-by: TatWai Chong <tatwai.chong at arm.com>
Change-Id: I6cdbeb3978f5ee540bc6cf59eb7c273eb0131430
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D131629
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Tosa/ops.mlir
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8518d6bf54842..edecc2cee96d0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1616,19 +1616,18 @@ def Tosa_ResizeOp : Tosa_Op<"resize", [
let description = [{
Resizes a tensor. Resize is only allowed in the H and W dimensions. In
- expected use, stride_y is approximately (IH<<shift)/OH and stride_x is
- approximately (IW<<shift)/OW. OH and OW are also supplied as inputs since
- there may be off by one errors if calculating OH and OW from the strides.
+ expected use, the height dimension is scaled by factor (scale_y_n/scale_y_d)
+ and the width dimension is scaled by factor (scale_x_n/scale_x_d). Thus the
+ output dimensions can be derived from the input dimensions by inverting the
+ scale. The [border_y, border_x] values adjust the output size to allow
+ fractional sampling beyond integer input position (IH-1,IW-1).
}];
let arguments = (ins
Tosa_Tensor4D:$input,
- Tosa_IntArrayAttr2:$output_size,
- Tosa_IntArrayAttr2:$stride,
+ Tosa_IntArrayAttr4:$scale,
Tosa_IntArrayAttr2:$offset,
- I32Attr:$shift,
- Tosa_Fp32ArrayAttr2:$stride_fp,
- Tosa_Fp32ArrayAttr2:$offset_fp,
+ Tosa_IntArrayAttr2:$border,
Tosa_ResizeTypeAttr:$mode
);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a634e2bf1b895..cfb7124ff743c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1373,47 +1373,65 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value inX =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);
- int32_t shift = op.getShift();
- bool floatingPointMode = shift == 0;
-
- Value yStride, xStride, yOffset, xOffset;
- if (floatingPointMode) {
- yStride = rewriter.create<arith::ConstantOp>(loc, op.getStrideFp()[0]);
- xStride = rewriter.create<arith::ConstantOp>(loc, op.getStrideFp()[1]);
- yOffset = rewriter.create<arith::ConstantOp>(loc, op.getOffsetFp()[0]);
- xOffset = rewriter.create<arith::ConstantOp>(loc, op.getOffsetFp()[1]);
- } else {
- SmallVector<int32_t> stride, offset;
- getValuesFromIntArrayAttribute(op.getStride(), stride);
- getValuesFromIntArrayAttribute(op.getOffset(), offset);
-
- yStride = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(stride[0]));
- xStride = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(stride[1]));
- yOffset = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(offset[0]));
- xOffset = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(offset[1]));
- }
+ bool floatingPointMode = resultElementTy.isF32();
+
+ Value yScaleN, yScaleD, xScaleN, xScaleD, yOffset, xOffset, yBorder,
+ xBorder;
+ SmallVector<int32_t> scale, offset, border;
+ getValuesFromIntArrayAttribute(op.getScale(), scale);
+ getValuesFromIntArrayAttribute(op.getOffset(), offset);
+ getValuesFromIntArrayAttribute(op.getBorder(), border);
+
+ yScaleN = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(scale[0]));
+ yScaleD = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(scale[1]));
+ xScaleN = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(scale[2]));
+ xScaleD = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(scale[3]));
+ yOffset = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(offset[0]));
+ xOffset = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(offset[1]));
+ yBorder = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(border[0]));
+ xBorder = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(border[1]));
// Compute the the integer index and partial offset.
- // x = x * stride + offset;
- // ix = floor(x)
- // dx = x - ix
Value ix, iy, dx, dy;
+ // x = x * scale_d + offset;
+ // ix = floor(x / scale_n)
if (floatingPointMode) {
+ // dx = x / scale_n - ix
Value y =
rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY);
Value x =
rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX);
- y = rewriter.create<arith::MulFOp>(loc, y, yStride);
- x = rewriter.create<arith::MulFOp>(loc, x, xStride);
+ yScaleN =
+ rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), yScaleN);
+ yScaleD =
+ rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), yScaleD);
+ xScaleN =
+ rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), xScaleN);
+ xScaleD =
+ rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), xScaleD);
+ yOffset =
+ rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), yOffset);
+ xOffset =
+ rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), xOffset);
+
+ y = rewriter.create<arith::MulFOp>(loc, y, yScaleD);
+ x = rewriter.create<arith::MulFOp>(loc, x, xScaleD);
y = rewriter.create<arith::AddFOp>(loc, y, yOffset);
x = rewriter.create<arith::AddFOp>(loc, x, xOffset);
+ y = rewriter.create<arith::DivFOp>(loc, y, yScaleN);
+ x = rewriter.create<arith::DivFOp>(loc, x, xScaleN);
+
iy = rewriter.create<math::FloorOp>(loc, y);
ix = rewriter.create<math::FloorOp>(loc, x);
@@ -1423,27 +1441,30 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy);
ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix);
} else {
- Value shiftVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(shift));
-
- Value y = rewriter.create<arith::MulIOp>(loc, inY, yStride);
- Value x = rewriter.create<arith::MulIOp>(loc, inX, xStride);
+ // dx = x - ix * scale_n;
+ Value y = rewriter.create<arith::MulIOp>(loc, inY, yScaleD);
+ Value x = rewriter.create<arith::MulIOp>(loc, inX, xScaleD);
y = rewriter.create<arith::AddIOp>(loc, y, yOffset);
x = rewriter.create<arith::AddIOp>(loc, x, xOffset);
- iy = rewriter.create<arith::ShRSIOp>(loc, y, shiftVal);
- ix = rewriter.create<arith::ShRSIOp>(loc, x, shiftVal);
+ iy = rewriter.create<arith::DivUIOp>(loc, y, yScaleN);
+ ix = rewriter.create<arith::DivUIOp>(loc, x, xScaleN);
- Value yTrunc = rewriter.create<arith::ShLIOp>(loc, iy, shiftVal);
- Value xTrunc = rewriter.create<arith::ShLIOp>(loc, ix, shiftVal);
+ Value temp_y = rewriter.create<arith::MulIOp>(loc, iy, yScaleN);
+ Value temp_x = rewriter.create<arith::MulIOp>(loc, ix, xScaleN);
- dy = rewriter.create<arith::SubIOp>(loc, y, yTrunc);
- dx = rewriter.create<arith::SubIOp>(loc, x, xTrunc);
+ dy = rewriter.create<arith::SubIOp>(loc, y, temp_y);
+ dx = rewriter.create<arith::SubIOp>(loc, x, temp_x);
}
if (op.getMode() == "NEAREST_NEIGHBOR") {
Value yPred, xPred;
+ auto zeroVal = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(0));
+ auto oneVal = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(1));
+
// Round the index position towards the closest pixel location.
if (floatingPointMode) {
auto halfVal = rewriter.create<arith::ConstantOp>(
@@ -1453,19 +1474,16 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
dx, halfVal);
} else {
- auto halfVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
+ Value yScaleNHalfVal =
+ rewriter.create<arith::ShRSIOp>(loc, yScaleN, oneVal);
+ Value xScaleNHalfVal =
+ rewriter.create<arith::ShRSIOp>(loc, xScaleN, oneVal);
yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
- dy, halfVal);
+ dy, yScaleNHalfVal);
xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
- dx, halfVal);
+ dx, xScaleNHalfVal);
}
- auto zeroVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(0));
- auto oneVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(1));
-
auto yOffset =
rewriter.create<arith::SelectOp>(loc, yPred, oneVal, zeroVal);
auto xOffset =
@@ -1491,9 +1509,8 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
rewriter.create<linalg::YieldOp>(loc, result);
return success();
- }
-
- if (op.getMode() == "BILINEAR") {
+ } else {
+ // The mode here must be BILINEAR. This has been checked above.
Value y0 = iy;
Value x0 = ix;
@@ -1527,10 +1544,8 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
loc, input, ValueRange{batch, y1, x1, channel});
if (floatingPointMode) {
- auto oneVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getF32FloatAttr(1.f));
Value rightPart = dx;
- Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);
+ Value leftPart = rewriter.create<arith::SubFOp>(loc, xScaleN, dx);
y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
@@ -1541,46 +1556,46 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);
Value bottomPart = dy;
- Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
+ Value topPart = rewriter.create<arith::SubFOp>(loc, yScaleN, dy);
topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
bottomAcc = rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);
rewriter.create<linalg::YieldOp>(loc, result);
return success();
- }
- y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
- y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
- y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
- y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
-
- if (resultElementTy.getIntOrFloatBitWidth() > 32) {
- dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
- dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
- }
+ } else {
+ y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
+ y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
+ y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
+ y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
+
+ if (resultElementTy.getIntOrFloatBitWidth() > 32) {
+ dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
+ dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
+ }
- auto unitVal = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(resultElementTy, 1LL << shift));
- Value rightPart = dx;
- Value leftPart = rewriter.create<arith::SubIOp>(loc, unitVal, dx);
+ Value rightPart = dx;
+ Value leftPart = rewriter.create<arith::SubIOp>(loc, xScaleN, dx);
- y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
- y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
- Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
+ y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
+ y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
+ Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
- y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
- y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
- Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
+ y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
+ y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
+ Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
- Value bottomPart = dy;
- Value topPart = rewriter.create<arith::SubIOp>(loc, unitVal, dy);
- topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
- bottomAcc = rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
- Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
+ Value bottomPart = dy;
+ Value topPart = rewriter.create<arith::SubIOp>(loc, yScaleN, dy);
+ topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
+ bottomAcc = rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
+ Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
- rewriter.create<linalg::YieldOp>(loc, result);
- return success();
+ rewriter.create<linalg::YieldOp>(loc, result);
+ return success();
+ }
}
+
return failure();
}
};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d944eece7ec90..8df30279cb7e0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -334,12 +334,6 @@ static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
}
}
-static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
- for (auto it : arrayAttr) {
- values.push_back(it.cast<FloatAttr>().getValueAsDouble());
- }
-}
-
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
SmallVector<int64_t> &outShape) {
int64_t outRank = 0;
@@ -800,64 +794,36 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
llvm::SmallVector<int64_t, 4> outputShape;
outputShape.resize(4, ShapedType::kDynamicSize);
- int32_t inHeight = ShapedType::kDynamicSize;
- int32_t inWidth = ShapedType::kDynamicSize;
-
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
- if (inputShape.hasRank()) {
- outputShape[0] = inputShape.getDimSize(0);
- outputShape[3] = inputShape.getDimSize(3);
+ if (!inputShape.hasRank())
+ return failure();
- inHeight = inputShape.getDimSize(1);
- inWidth = inputShape.getDimSize(2);
- }
+ outputShape[0] = inputShape.getDimSize(0);
+ outputShape[3] = inputShape.getDimSize(3);
+ int32_t inputHeight = inputShape.getDimSize(1);
+ int32_t inputWidth = inputShape.getDimSize(2);
- int32_t shift = adaptor.getShift();
- llvm::SmallVector<int64_t> newShape;
- getI64Values(adaptor.getOutputSize(), newShape);
- outputShape[1] = newShape[0];
- outputShape[2] = newShape[1];
+ if ((inputHeight == ShapedType::kDynamicSize) ||
+ (inputWidth == ShapedType::kDynamicSize))
+ return failure();
- llvm::SmallVector<int64_t> strideInt;
+ llvm::SmallVector<int64_t> scaleInt;
llvm::SmallVector<int64_t> offsetInt;
- llvm::SmallVector<double> strideFp;
- llvm::SmallVector<double> offsetFp;
+ llvm::SmallVector<int64_t> borderInt;
+ getI64Values(adaptor.getScale(), scaleInt);
getI64Values(adaptor.getOffset(), offsetInt);
- getF64Values(adaptor.getOffsetFp(), offsetFp);
- getI64Values(adaptor.getStride(), strideInt);
- getF64Values(adaptor.getStrideFp(), strideFp);
-
- // If we have a 0 zero in integers we know that the resize indexing needs to
- // be performed in floating point. Use the floating point varient to compute
- // the resize shape.
- bool fpMode = strideInt[0] == 0;
-
- // We can compute the output shape if attribute specifies unknown dimensions
- // based on the offset and stride. If we perfectly line up to the last index
- // we need to round up the size to include it.
- if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) {
- float sizeFp = (inHeight - offsetFp[0] - 1) / strideFp[0];
- float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
- outputShape[1] = std::ceil(sizeFp) + round;
- }
-
- if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) {
- float sizeFp = (inWidth - offsetFp[1] - 1) / strideFp[1];
- float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
- outputShape[2] = std::ceil(sizeFp) + round;
- }
-
- if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) {
- int64_t size = (inHeight - 1);
- size = ((size << shift) - offsetInt[0]) / strideInt[0];
- outputShape[1] = size + 1;
- }
-
- if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) {
- int64_t size = (inWidth - 1);
- size = ((size << shift) - offsetInt[1]) / strideInt[1];
- outputShape[2] = size + 1;
- }
+ getI64Values(adaptor.getBorder(), borderInt);
+
+ // Compute the output shape based on attributes: scale, offset, and border.
+ outputShape[1] =
+ (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
+ scaleInt[1]) +
+ 1;
+
+ outputShape[2] =
+ (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
+ scaleInt[3]) +
+ 1;
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 0c8af01c112e7..a7297281d5ab3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1623,124 +1623,141 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// -----
-// CHECK-LABEL: @resize_nearest
-func.func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
- // CHECK: %[[INIT:.+]] = tensor.empty()
+// CHECK-LABEL: @resize_nearest_int
+func.func @resize_nearest_int(%arg0: tensor<1x15x13x1xi8>) -> () {
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x23x179x1xi8>
// CHECK: %[[GENERIC:.+]] = linalg.generic
- // CHECK: %[[IDX0:.+]] = linalg.index 0
- // CHECK: %[[IDX1:.+]] = linalg.index 1
- // CHECK: %[[IDX2:.+]] = linalg.index 2
- // CHECK: %[[IDX3:.+]] = linalg.index 3
- // CHECK-DAG: %[[XYMIN:.+]] = arith.constant 0
- // CHECK-DAG: %[[YMAX:.+]] = arith.constant 1
- // CHECK-DAG: %[[XMAX:.+]] = arith.constant 1
- // CHECK-DAG: %[[Y:.+]] = arith.index_cast %[[IDX1]]
- // CHECK-DAG: %[[X:.+]] = arith.index_cast %[[IDX2]]
- // CHECK-DAG: %[[STRIDEY:.+]] = arith.constant 5.000000e-01
- // CHECK-DAG: %[[STRIDEX:.+]] = arith.constant 5.000000e-01
- // CHECK-DAG: %[[OFFSETY:.+]] = arith.constant 1.000000e-01
- // CHECK-DAG: %[[OFFSETX:.+]] = arith.constant 2.000000e-01
- // CHECK-DAG: %[[VAL4:.+]] = arith.uitofp %[[Y]]
- // CHECK-DAG: %[[VAL5:.+]] = arith.uitofp %[[X]]
- // CHECK-DAG: %[[VAL6:.+]] = arith.mulf %[[VAL4]], %[[STRIDEY]]
- // CHECK-DAG: %[[VAL7:.+]] = arith.mulf %[[VAL5]], %[[STRIDEX]]
- // CHECK-DAG: %[[VAL8:.+]] = arith.addf %[[VAL6]], %[[OFFSETY]]
- // CHECK-DAG: %[[VAL9:.+]] = arith.addf %[[VAL7]], %[[OFFSETX]]
-
- // Find the remainder and integer component of the target index.
-
- // CHECK-DAG: %[[VAL10:.+]] = math.floor %[[VAL8]]
- // CHECK-DAG: %[[VAL11:.+]] = math.floor %[[VAL9]]
- // CHECK-DAG: %[[VAL12:.+]] = arith.subf %[[VAL8]], %[[VAL10]]
- // CHECK-DAG: %[[VAL13:.+]] = arith.subf %[[VAL9]], %[[VAL11]]
- // CHECK-DAG: %[[VAL14:.+]] = arith.fptosi %[[VAL10]]
- // CHECK-DAG: %[[VAL15:.+]] = arith.fptosi %[[VAL11]]
-
- // Round to the nearest index.
-
- // CHECK-DAG: %[[ROUND:.+]] = arith.constant 5.000000e-01
- // CHECK-DAG: %[[VAL16:.+]] = arith.cmpf oge, %[[VAL12]], %[[ROUND]]
- // CHECK-DAG: %[[VAL17:.+]] = arith.cmpf oge, %[[VAL13]], %[[ROUND]]
- // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
- // CHECK-DAG: %[[ONE:.+]] = arith.constant 1
- // CHECK-DAG: %[[VAL18:.+]] = arith.select %[[VAL16]], %[[ONE]], %[[ZERO]]
- // CHECK-DAG: %[[VAL19:.+]] = arith.select %[[VAL17]], %[[ONE]], %[[ZERO]]
- // CHECK-DAG: %[[VAL20:.+]] = arith.addi %[[VAL14]], %[[VAL18]]
- // CHECK-DAG: %[[VAL21:.+]] = arith.addi %[[VAL15]], %[[VAL19]]
+ // CHECK: %[[IDX_0:.+]] = linalg.index 0
+ // CHECK: %[[IDX_1:.+]] = linalg.index 1
+ // CHECK: %[[IDX_2:.+]] = linalg.index 2
+ // CHECK: %[[IDX_3:.+]] = linalg.index 3
+ // CHECK: %[[XY_MIN:.+]] = arith.constant 0
+ // CHECK: %[[Y_MAX:.+]] = arith.constant 14
+ // CHECK: %[[X_MAX:.+]] = arith.constant 12
+
+ // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
+ // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
+ // CHECK: %[[SCALE_Y_N:.*]] = arith.constant 11
+ // CHECK: %[[SCALE_Y_D:.*]] = arith.constant 7
+ // CHECK: %[[SCALE_X_N:.*]] = arith.constant 89
+ // CHECK: %[[SCALE_X_D:.*]] = arith.constant 6
+ // CHECK: %[[OFFSET_Y:.*]] = arith.constant 0
+ // CHECK: %[[OFFSET_X:.*]] = arith.constant 0
+ // CHECK: %[[BORDER_Y:.*]] = arith.constant 0
+ // CHECK: %[[BORDER_X:.*]] = arith.constant 0
+
+ // Find the remainder and integer component of the target index.
+
+ // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]]
+ // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
+ // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]]
+ // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
+ // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]]
+ // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]]
+ // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]]
+ // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
+ // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]]
+ // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]]
+
+ // Round to the nearest neighbor.
+
+ // CHECK: %[[ZERO:.*]] = arith.constant 0
+ // CHECK: %[[ONE:.*]] = arith.constant 1
+ // CHECK: %[[SCALE_Y_N_HALF:.*]] = arith.shrsi %[[SCALE_Y_N]], %[[ONE]]
+ // CHECK: %[[SCALE_X_N_HALF:.*]] = arith.shrsi %[[SCALE_X_N]], %[[ONE]]
+ // CHECK: %[[PRED_Y:.*]] = arith.cmpi sge, %[[D_Y]], %[[SCALE_Y_N_HALF]]
+ // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X]], %[[SCALE_X_N_HALF]]
+ // CHECK: %[[VAL_37:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
+ // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
+ // CHECK: %[[VAL_39:.*]] = arith.addi %[[I_Y]], %[[VAL_37]]
+ // CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]]
// This section applies bound checking to be within the input image.
- // CHECK-DAG: %[[VAL22:.+]] = arith.cmpi slt, %[[VAL20]], %[[XYMIN]]
- // CHECK-DAG: %[[VAL23:.+]] = arith.select %[[VAL22]], %[[XYMIN]], %[[VAL20]]
- // CHECK-DAG: %[[VAL24:.+]] = arith.cmpi slt, %[[YMAX]], %[[VAL20]]
- // CHECK-DAG: %[[VAL25:.+]] = arith.select %[[VAL24]], %[[YMAX]], %[[VAL23]]
- // CHECK-DAG: %[[VAL26:.+]] = arith.cmpi slt, %[[VAL21]], %[[XYMIN]]
- // CHECK-DAG: %[[VAL27:.+]] = arith.select %[[VAL26]], %[[XYMIN]], %[[VAL21]]
- // CHECK-DAG: %[[VAL28:.+]] = arith.cmpi slt, %[[XMAX]], %[[VAL21]]
- // CHECK-DAG: %[[VAL29:.+]] = arith.select %[[VAL28]], %[[XMAX]], %[[VAL27]]
+ // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[XY_MIN]]
+ // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[XY_MIN]], %[[VAL_39]]
+ // CHECK: %[[VAL_43:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[VAL_39]]
+ // CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[Y_MAX]], %[[VAL_42]]
+ // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[XY_MIN]]
+ // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[XY_MIN]], %[[VAL_40]]
+ // CHECK: %[[VAL_47:.*]] = arith.cmpi slt, %[[X_MAX]], %[[VAL_40]]
+ // CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_47]], %[[X_MAX]], %[[VAL_46]]
// Extract the nearest value using the computed indices.
- // CHECK-DAG: %[[IDY:.+]] = arith.index_cast %[[VAL25]]
- // CHECK-DAG: %[[IDX:.+]] = arith.index_cast %[[VAL29]]
- // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]]
+ // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]]
+ // CHECK: %[[IDX:.+]] = arith.index_cast %[[VAL_48]]
+ // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[IDY]], %[[IDX]], %[[IDX_3]]]
// CHECK: linalg.yield %[[EXTRACT]]
- %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [0, 0], offset = [0, 0], stride_fp = [0.5 : f32, 0.5 : f32], offset_fp = [0.1 : f32, 0.2 : f32], shift = 0 : i32, mode = "NEAREST_NEIGHBOR" } : (tensor<1x2x2x1xf32>) -> (tensor<1x4x4x1xf32>)
- return
+ // Round to the nearest index.
+ %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor<1x23x179x1xi8>
+ return
}
// -----
-// CHECK-LABEL: @resize_bilinear
+// CHECK-LABEL: @resize_bilinear_int
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () {
- // CHECK: %[[INIT:.+]] = tensor.empty()
+func.func @resize_bilinear_int(%arg0: tensor<1x19x19x1xi8>) {
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x289x289x1xi32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
- // CHECK: %[[IDX0:.+]] = linalg.index 0
- // CHECK: %[[IDX1:.+]] = linalg.index 1
- // CHECK: %[[IDX2:.+]] = linalg.index 2
- // CHECK: %[[IDX3:.+]] = linalg.index 3
- // CHECK: %[[XYMIN:.+]] = arith.constant 0
- // CHECK: %[[YMAX:.+]] = arith.constant 1
- // CHECK: %[[XMAX:.+]] = arith.constant 1
-
- // CHECK: %[[VAL10:.+]] = math.floor %[[VAL8:.+]]
- // CHECK: %[[VAL11:.+]] = math.floor %[[VAL9:.+]]
-
- // CHECK: %[[DY:.+]] = arith.subf %[[VAL8:.+]], %[[VAL10]]
- // CHECK: %[[DX:.+]] = arith.subf %[[VAL9:.+]], %[[VAL11]]
-
- // CHECK: %[[Y0:.+]] = arith.fptosi %[[VAL10]]
- // CHECK: %[[X0:.+]] = arith.fptosi %[[VAL11]]
+ // CHECK: %[[IDX_0:.+]] = linalg.index 0
+ // CHECK: %[[IDX_1:.+]] = linalg.index 1
+ // CHECK: %[[IDX_2:.+]] = linalg.index 2
+ // CHECK: %[[IDX_3:.+]] = linalg.index 3
+ // CHECK: %[[XY_MIN:.+]] = arith.constant 0
+ // CHECK: %[[Y_MAX:.+]] = arith.constant 18
+ // CHECK: %[[X_MAX:.+]] = arith.constant 18
+ // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
+ // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
+ // CHECK: %[[SCALE_Y_N:.*]] = arith.constant 16
+ // CHECK: %[[SCALE_Y_D:.*]] = arith.constant 1
+ // CHECK: %[[SCALE_X_N:.*]] = arith.constant 16
+ // CHECK: %[[SCALE_X_D:.*]] = arith.constant 1
+ // CHECK: %[[OFFSET_Y:.*]] = arith.constant 0
+ // CHECK: %[[OFFSET_X:.*]] = arith.constant 0
+ // CHECK: %[[BORDER_Y:.*]] = arith.constant 0
+ // CHECK: %[[BORDER_X:.*]] = arith.constant 0
+
+ // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]]
+ // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
+ // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]]
+ // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
+ // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]]
+ // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]]
+ // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]]
+ // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
+ // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]]
+ // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]]
// Compute the left, right, and top indices for the bilinear interpolation.
- // CHECK: %[[ONE:.+]] = arith.constant 1
- // CHECK: %[[Y1:.+]] = arith.addi %[[Y0]], %[[ONE]]
- // CHECK: %[[X1:.+]] = arith.addi %[[X0]], %[[ONE]]
+ // CHECK: %[[ONE:.*]] = arith.constant 1
+ // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]]
+ // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
// Bound check each dimension.
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y0]], %[[XYMIN]]
- // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y0]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y0]]
- // CHECK: %[[YLO:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]]
+ // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]]
+ // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y1]], %[[XYMIN]]
- // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y1]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y1]]
- // CHECK: %[[YHI:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]]
+ // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]]
+ // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X0]], %[[XYMIN]]
- // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X0]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X0]]
- // CHECK: %[[XLO:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]]
+ // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]]
+ // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X1]], %[[XYMIN]]
- // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X1]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X1]]
- // CHECK: %[[XHI:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]]
+ // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]]
+ // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
// Extract each corner of the bilinear interpolation.
@@ -1749,182 +1766,215 @@ func.func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () {
// CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]]
// CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]]
- // CHECK: %[[LOLO:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[YLOI]], %[[XLOI]], %[[IDX3]]]
- // CHECK: %[[LOHI:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[YLOI]], %[[XHII]], %[[IDX3]]]
- // CHECK: %[[HILO:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[YHII]], %[[XLOI]], %[[IDX3]]]
- // CHECK: %[[HIHI:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[YHII]], %[[XHII]], %[[IDX3]]]
+ // CHECK: %[[LOLO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]]
+ // CHECK: %[[LOHI:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]]
+ // CHECK: %[[HILO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]]
+ // CHECK: %[[HIHI:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YHII]], %[[XHII]], %[[IDX_3]]]
+
+ // CHECK: %[[XLOLO:.+]] = arith.extsi %[[LOLO]]
+ // CHECK: %[[XLOHI:.+]] = arith.extsi %[[LOHI]]
+ // CHECK: %[[XHILO:.+]] = arith.extsi %[[HILO]]
+ // CHECK: %[[XHIHI:.+]] = arith.extsi %[[HIHI]]
// Compute the bilinear interpolation.
- // CHECK: %[[ONE:.+]] = arith.constant 1.000000e+00
- // CHECK: %[[NDX:.+]] = arith.subf %[[ONE]], %[[DX]]
- // CHECK: %[[WLOLO:.+]] = arith.mulf %[[LOLO]], %[[NDX]]
- // CHECK: %[[WLOHI:.+]] = arith.mulf %[[LOHI]], %[[DX]]
- // CHECK: %[[LO:.+]] = arith.addf %[[WLOLO]], %[[WLOHI]]
- // CHECK: %[[WHILO:.+]] = arith.mulf %[[HILO]], %[[NDX]]
- // CHECK: %[[WHIHI:.+]] = arith.mulf %[[HIHI]], %[[DX]]
- // CHECK: %[[HI:.+]] = arith.addf %[[WHILO]], %[[WHIHI]]
- // CHECK: %[[NDY:.+]] = arith.subf %[[ONE]], %[[DY]]
- // CHECK: %[[WLO:.+]] = arith.mulf %[[LO]], %[[NDY]]
- // CHECK: %[[WHI:.+]] = arith.mulf %[[HI]], %[[DY]]
- // CHECK: %[[RESULT:.+]] = arith.addf %[[WLO]], %[[WHI]]
+ // CHECK: %[[NDX:.+]] = arith.subi %[[SCALE_X_N]], %[[D_X]]
+ // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]]
+ // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X]]
+ // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]]
+ // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]]
+ // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X]]
+ // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]]
+ // CHECK: %[[NDY:.+]] = arith.subi %[[SCALE_Y_N]], %[[D_Y]]
+ // CHECK: %[[WLO:.+]] = arith.muli %[[LO]], %[[NDY]]
+ // CHECK: %[[WHI:.+]] = arith.muli %[[HI]], %[[D_Y]]
+ // CHECK: %[[RESULT:.+]] = arith.addi %[[WLO]], %[[WHI]]
// CHECK: linalg.yield %[[RESULT]]
- %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [0, 0], offset = [0, 0], stride_fp = [0.5 : f32, 0.5 : f32], offset_fp = [0.1 : f32, 0.2 : f32], shift = 0 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xf32>) -> (tensor<1x4x4x1xf32>)
- return
+
+ // Round to the nearest index.
+ %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x19x1xi8>) -> tensor<1x289x289x1xi32>
+ return
}
// -----
-// CHECK-LABEL: @resize_nearest_int
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @resize_nearest_int(%input: tensor<1x2x2x1xi32>) -> () {
- // CHECK: %[[INIT:.+]] = tensor.empty()
+// CHECK-LABEL: @resize_nearest_fp
+func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () {
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x1600x1536x1xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK: %[[IDX1:.+]] = linalg.index 1
// CHECK: %[[IDX2:.+]] = linalg.index 2
// CHECK: %[[IDX3:.+]] = linalg.index 3
- // CHECK-DAG: %[[XYMIN:.+]] = arith.constant 0
- // CHECK-DAG: %[[YMAX:.+]] = arith.constant 1
- // CHECK-DAG: %[[XMAX:.+]] = arith.constant 1
- // CHECK-DAG: %[[Y:.+]] = arith.index_cast %[[IDX1]]
- // CHECK-DAG: %[[X:.+]] = arith.index_cast %[[IDX2]]
- // CHECK-DAG: %[[STRIDEY:.+]] = arith.constant 128
- // CHECK-DAG: %[[STRIDEX:.+]] = arith.constant 128
- // CHECK-DAG: %[[OFFSETY:.+]] = arith.constant 1
- // CHECK-DAG: %[[OFFSETX:.+]] = arith.constant 2
- // CHECK-DAG: %[[EIGHT:.+]] = arith.constant 8
- // CHECK-DAG: %[[VAL4:.+]] = arith.muli %[[Y]], %[[STRIDEY]]
- // CHECK-DAG: %[[VAL5:.+]] = arith.muli %[[X]], %[[STRIDEX]]
- // CHECK-DAG: %[[VAL6:.+]] = arith.addi %[[VAL4]], %[[OFFSETY]]
- // CHECK-DAG: %[[VAL7:.+]] = arith.addi %[[VAL5]], %[[OFFSETX]]
+ // CHECK: %[[XYMIN:.*]] = arith.constant 0
+ // CHECK: %[[YMAX:.*]] = arith.constant 49
+ // CHECK: %[[XMAX:.*]] = arith.constant 47
+ // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX1]]
+ // CHECK: %[[X:.+]] = arith.index_cast %[[IDX2]]
+ // CHECK: %[[ISCALE_Y_N:.*]] = arith.constant 64
+ // CHECK: %[[ISCALE_Y_D:.*]] = arith.constant 2
+ // CHECK: %[[ISCALE_X_N:.*]] = arith.constant 64
+ // CHECK: %[[ISCALE_X_D:.*]] = arith.constant 2
+ // CHECK: %[[IOFFSET_Y:.*]] = arith.constant -31
+ // CHECK: %[[IOFFSET_X:.*]] = arith.constant -31
+ // CHECK: %[[IBORDER_Y:.*]] = arith.constant 31
+ // CHECK: %[[IBORDER_X:.*]] = arith.constant 31
+
+ // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]]
+ // CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
+ // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]]
+ // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]]
+ // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]]
+ // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]]
+ // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]]
+ // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]]
+
+ // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
+ // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]]
+ // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
+ // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]]
+ // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
+ // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]]
// Find the remainder and integer component of the target index.
+ // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
+ // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]]
+ // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
+ // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]]
+ // CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]]
+ // CHECK: %[[VAL_40:.*]] = arith.fptosi %[[VAL_36]]
+
+ // CHECK: %[[ZERO:.*]] = arith.constant 0
+ // CHECK: %[[ONE:.*]] = arith.constant 1
+ // CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01
+ // CHECK: %[[PRED_Y:.*]] = arith.cmpf oge, %[[D_Y]], %[[HALF]]
+ // CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]]
+ // CHECK: %[[ROUND_Y:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
+ // CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
+ // CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_39]], %[[ROUND_Y]]
+ // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]]
+
+ // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[XYMIN]]
+ // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[XYMIN]], %[[VAL_48]]
+ // CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[YMAX]], %[[VAL_48]]
+ // CHECK: %[[VAL_53:.*]] = arith.select %[[VAL_52]], %[[YMAX]], %[[VAL_51]]
+ // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[XYMIN]]
+ // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[XYMIN]], %[[VAL_49]]
+ // CHECK: %[[VAL_56:.*]] = arith.cmpi slt, %[[XMAX]], %[[VAL_49]]
+ // CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[XMAX]], %[[VAL_55]]
+
+ // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]]
+ // CHECK: %[[IDX:.*]] = arith.index_cast %[[VAL_57]]
+ // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]]
+ // CHECK: linalg.yield %[[EXTRACT]]
- // CHECK-DAG: %[[VAL8:.+]] = arith.shrsi %[[VAL6]], %[[EIGHT]]
- // CHECK-DAG: %[[VAL9:.+]] = arith.shrsi %[[VAL7]], %[[EIGHT]]
- // CHECK-DAG: %[[VAL10:.+]] = arith.shli %[[VAL8]], %[[EIGHT]]
- // CHECK-DAG: %[[VAL11:.+]] = arith.shli %[[VAL9]], %[[EIGHT]]
- // CHECK-DAG: %[[VAL12:.+]] = arith.subi %[[VAL6]], %[[VAL10]]
- // CHECK-DAG: %[[VAL13:.+]] = arith.subi %[[VAL7]], %[[VAL11]]
-
- // Round to the nearest index.
-
- // CHECK-DAG: %[[ROUND:.+]] = arith.constant 128
- // CHECK-DAG: %[[VAL16:.+]] = arith.cmpi sge, %[[VAL12]], %[[ROUND]]
- // CHECK-DAG: %[[VAL17:.+]] = arith.cmpi sge, %[[VAL13]], %[[ROUND]]
- // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
- // CHECK-DAG: %[[ONE:.+]] = arith.constant 1
- // CHECK-DAG: %[[VAL18:.+]] = arith.select %[[VAL16]], %[[ONE]], %[[ZERO]]
- // CHECK-DAG: %[[VAL19:.+]] = arith.select %[[VAL17]], %[[ONE]], %[[ZERO]]
- // CHECK-DAG: %[[VAL20:.+]] = arith.addi %[[VAL8]], %[[VAL18]]
- // CHECK-DAG: %[[VAL21:.+]] = arith.addi %[[VAL9]], %[[VAL19]]
-
- // This section applies bound checking to be within the input image.
-
- // CHECK-DAG: %[[VAL22:.+]] = arith.cmpi slt, %[[VAL20]], %[[XYMIN]]
- // CHECK-DAG: %[[VAL23:.+]] = arith.select %[[VAL22]], %[[XYMIN]], %[[VAL20]]
- // CHECK-DAG: %[[VAL24:.+]] = arith.cmpi slt, %[[YMAX]], %[[VAL20]]
- // CHECK-DAG: %[[VAL25:.+]] = arith.select %[[VAL24]], %[[YMAX]], %[[VAL23]]
- // CHECK-DAG: %[[VAL26:.+]] = arith.cmpi slt, %[[VAL21]], %[[XYMIN]]
- // CHECK-DAG: %[[VAL27:.+]] = arith.select %[[VAL26]], %[[XYMIN]], %[[VAL21]]
- // CHECK-DAG: %[[VAL28:.+]] = arith.cmpi slt, %[[XMAX]], %[[VAL21]]
- // CHECK-DAG: %[[VAL29:.+]] = arith.select %[[VAL28]], %[[XMAX]], %[[VAL27]]
+ %output = "tosa.resize"(%input) {mode = "NEAREST_NEIGHBOR", scale = [64, 2, 64, 2], offset = [-31, -31], border = [31, 31]} : (tensor<1x50x48x1xf32>) -> tensor<1x1600x1536x1xf32>
- // Extract the nearest value using the computed indices.
-
- // CHECK-DAG: %[[IDY:.+]] = arith.index_cast %[[VAL25]]
- // CHECK-DAG: %[[IDX:.+]] = arith.index_cast %[[VAL29]]
- // CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]]
- // CHECK: linalg.yield %[[EXTRACT]]
- %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "NEAREST_NEIGHBOR" } : (tensor<1x2x2x1xi32>) -> (tensor<1x4x4x1xi32>)
return
}
// -----
-// CHECK-LABEL: @resize_bilinear_int
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () {
- // CHECK: %[[INIT:.+]] = tensor.empty()
+// CHECK-LABEL: @resize_bilinear_fp
+func.func @resize_bilinear_fp(%input: tensor<1x23x23x1xf32>) -> () {
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x89x89x1xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
-
- // CHECK: %[[IDX0:.+]] = linalg.index 0
- // CHECK: %[[IDX3:.+]] = linalg.index 3
-
- // CHECK: %[[XYMIN:.+]] = arith.constant 0
- // CHECK: %[[YMAX:.+]] = arith.constant 1
- // CHECK: %[[XMAX:.+]] = arith.constant 1
-
- // CHECK: %[[Y0:.+]] = arith.shrsi
- // CHECK: %[[X0:.+]] = arith.shrsi
- // CHECK: %[[ROUNDY:.+]] = arith.shli %[[Y0]]
- // CHECK: %[[ROUNDX:.+]] = arith.shli %[[X0]]
- // CHECK: %[[DY:.+]] = arith.subi %10, %[[ROUNDY]]
- // CHECK: %[[DX:.+]] = arith.subi %11, %[[ROUNDX]]
+ // CHECK: %[[IDX_0:.+]] = linalg.index 0
+ // CHECK: %[[IDX_1:.+]] = linalg.index 1
+ // CHECK: %[[IDX_2:.+]] = linalg.index 2
+ // CHECK: %[[IDX_3:.+]] = linalg.index 3
+ // CHECK: %[[XY_MIN:.*]] = arith.constant 0
+ // CHECK: %[[Y_MAX:.*]] = arith.constant 22
+ // CHECK: %[[X_MAX:.*]] = arith.constant 22
+ // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
+ // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
+ // CHECK: %[[ISCALE_Y_N:.*]] = arith.constant 4
+ // CHECK: %[[ISCALE_Y_D:.*]] = arith.constant 1
+ // CHECK: %[[ISCALE_X_N:.*]] = arith.constant 4
+ // CHECK: %[[ISCALE_X_D:.*]] = arith.constant 1
+ // CHECK: %[[IOFFSET_Y:.*]] = arith.constant 0
+ // CHECK: %[[IOFFSET_X:.*]] = arith.constant 0
+ // CHECK: %[[IBORDER_Y:.*]] = arith.constant 0
+ // CHECK: %[[IBORDER_X:.*]] = arith.constant 0
+
+ // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]]
+ // CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
+ // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]]
+ // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]]
+ // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]]
+ // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]]
+ // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]]
+ // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]]
+
+ // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
+ // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]]
+ // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
+ // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]]
+ // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
+ // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]]
+
+ // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
+ // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]]
+ // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
+ // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]]
+ // CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]]
+ // CHECK: %[[I_X:.*]] = arith.fptosi %[[VAL_36]]
// Compute the left, right, and top indices for the bilinear interpolation.
- // CHECK: %[[ONE:.+]] = arith.constant 1
- // CHECK: %[[Y1:.+]] = arith.addi %[[Y0]], %[[ONE]]
- // CHECK: %[[X1:.+]] = arith.addi %[[X0]], %[[ONE]]
+ // CHECK: %[[ONE:.*]] = arith.constant 1
+ // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]]
+ // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
// Bound check each dimension.
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y0]], %[[XYMIN]]
- // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y0]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y0]]
- // CHECK: %[[YLO:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]]
+ // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]]
+ // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[Y1]], %[[XYMIN]]
- // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[Y1]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[YMAX]], %[[Y1]]
- // CHECK: %[[YHI:.+]] = arith.select %[[PRED]], %[[YMAX]], %[[BOUND]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]]
+ // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]]
+ // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X0]], %[[XYMIN]]
- // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X0]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X0]]
- // CHECK: %[[XLO:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]]
+ // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]]
+ // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[X1]], %[[XYMIN]]
- // CHECK: %[[BOUND:.+]] = arith.select %[[PRED]], %[[XYMIN]], %[[X1]]
- // CHECK: %[[PRED:.+]] = arith.cmpi slt, %[[XMAX]], %[[X1]]
- // CHECK: %[[XHI:.+]] = arith.select %[[PRED]], %[[XMAX]], %[[BOUND]]
-
- // Extract each corner of the bilinear interpolation.
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]]
+ // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]]
+ // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]]
+ // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
// CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
// CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
// CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]]
// CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]]
- // CHECK: %[[LOLO:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[YLOI]], %[[XLOI]], %[[IDX3]]]
- // CHECK: %[[LOHI:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[YLOI]], %[[XHII]], %[[IDX3]]]
- // CHECK: %[[HILO:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[YHII]], %[[XLOI]], %[[IDX3]]]
- // CHECK: %[[HIHI:.+]] = tensor.extract %[[ARG0]][%[[IDX0]], %[[YHII]], %[[XHII]], %[[IDX3]]]
+ // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]]
+ // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]]
+ // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]]
+ // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YHII]], %[[XHII]], %[[IDX_3]]]
- // CHECK: %[[XLOLO:.+]] = arith.extsi %[[LOLO]]
- // CHECK: %[[XLOHI:.+]] = arith.extsi %[[LOHI]]
- // CHECK: %[[XHILO:.+]] = arith.extsi %[[HILO]]
- // CHECK: %[[XHIHI:.+]] = arith.extsi %[[HIHI]]
+ // CHECK: %[[NDX:.+]] = arith.subf %[[SCALE_X_N]], %[[D_X]]
+ // CHECK: %[[WLOLO:.+]] = arith.mulf %[[LOLO]], %[[NDX]]
+ // CHECK: %[[WLOHI:.+]] = arith.mulf %[[LOHI]], %[[D_X]]
+ // CHECK: %[[LO:.+]] = arith.addf %[[WLOLO]], %[[WLOHI]]
+ // CHECK: %[[WHILO:.+]] = arith.mulf %[[HILO]], %[[NDX]]
+ // CHECK: %[[WHIHI:.+]] = arith.mulf %[[HIHI]], %[[D_X]]
+ // CHECK: %[[HI:.+]] = arith.addf %[[WHILO]], %[[WHIHI]]
+ // CHECK: %[[NDY:.+]] = arith.subf %[[SCALE_Y_N]], %[[D_Y]]
+ // CHECK: %[[WLO:.+]] = arith.mulf %[[LO]], %[[NDY]]
+ // CHECK: %[[WHI:.+]] = arith.mulf %[[HI]], %[[D_Y]]
+ // CHECK: %[[RESULT:.+]] = arith.addf %[[WLO]], %[[WHI]]
+ // CHECK: linalg.yield %[[RESULT]]
- // Compute the bilinear interpolation.
+ // Round by bilinear interpolation
+ %output = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x23x1xf32>) -> tensor<1x89x89x1xf32>
- // CHECK: %[[SCALE:.+]] = arith.constant 256
- // CHECK: %[[NDX:.+]] = arith.subi %[[SCALE]], %[[DX]]
- // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]]
- // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[DX]]
- // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]]
- // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]]
- // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[DX]]
- // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]]
- // CHECK: %[[NDY:.+]] = arith.subi %[[SCALE]], %[[DY]]
- // CHECK: %[[WLO:.+]] = arith.muli %[[LO]], %[[NDY]]
- // CHECK: %[[WHI:.+]] = arith.muli %[[HI]], %[[DY]]
- // CHECK: %[[RESULT:.+]] = arith.addi %[[WLO]], %[[WHI]]
- // CHECK: linalg.yield %[[RESULT]]
- %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xi8>) -> (tensor<1x4x4x1xi32>)
return
}
@@ -1933,10 +1983,10 @@ func.func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () {
// CHECK-LABEL: @resize_dyn
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @resize_dyn(%input: tensor<?x2x2x1xi8>) -> () {
- // CHECK: %[[C0:.+]] = arith.constant 0
- // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]])
+ // CHECK: %[[C0:.+]] = arith.constant 0
+ // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x4x4x1xi32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
- %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor<?x2x2x1xi8>) -> (tensor<?x4x4x1xi32>)
+ %output = "tosa.resize"(%input) { scale = [4, 2, 4, 2], offset = [-1, -1], border = [1, 1], mode = "BILINEAR" } : (tensor<?x2x2x1xi8>) -> (tensor<?x4x4x1xi32>)
return
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 37b25a1519e1b..7894b07a8ef42 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -448,7 +448,7 @@ func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %a
// -----
// CHECK-LABEL: resize
func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
- %1 = "tosa.resize"(%arg0) {output_size = [64, 64], stride = [1024, 1024], offset = [0, 0], shift = 10 : i32, stride_fp = [0.0 : f32, 0.0 : f32], offset_fp = [0.0 : f32, 0.0 : f32], mode = "BILINEAR"} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32>
+ %1 = "tosa.resize"(%arg0) { scale = [4, 2, 4, 2], offset = [-1, -1], border = [1, 1], mode = "BILINEAR"} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32>
return %1 : tensor<1x64x64x8xf32>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 3fd70b2510167..311851a72c889 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -964,61 +964,71 @@ func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x
// -----
-// CHECK-LABEL: @resize_output_size
-func.func @resize_output_size(%arg0: tensor<2x?x?x3xi32>) {
- // CHECK: -> tensor<2x4x5x3xi32>
- %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 1], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [4, 5], shift = 8 : i32, stride = [1, 1], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<2x?x?x3xi32>) -> tensor<?x?x?x?xi32>
+// CHECK-LABEL: @resize_int_horizontal
+func.func @resize_int_horizontal(%arg0: tensor<1x15x13x1xi8>) {
+ // CHECK: -> tensor<1x23x179x1xi8>
+ %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor<?x?x?x?xi8>
return
}
// -----
-// CHECK-LABEL: @resize_int_horizontal
-func.func @resize_int_horizontal(%arg0: tensor<1x2x4x1xi32>) {
- // CHECK: -> tensor<1x2x7x1xi32>
- %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 8 : i32, stride = [256, 128], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
+// CHECK-LABEL: @resize_int_vertical
+func.func @resize_int_vertical(%arg0: tensor<1x49x42x1xi16>) {
+ // CHECK: -> tensor<1x112x220x1xi16>
+ %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [37, 16, 219, 41], offset = [0, 0], border = [0, 0]} : (tensor<1x49x42x1xi16>) -> tensor<?x?x?x?xi16>
return
}
// -----
-// CHECK-LABEL: @resize_int_vertical
-func.func @resize_int_vertical(%arg0: tensor<1x2x4x1xi32>) {
- // CHECK: -> tensor<1x3x4x1xi32>
- %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 8 : i32, stride = [128, 256], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
+// CHECK-LABEL: @resize_int_power_of_two_upscale
+func.func @resize_int_power_of_two_upscale(%arg0: tensor<1x23x19x1xi8>) {
+ // CHECK: -> tensor<1x353x289x1xi32>
+ %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x19x1xi8>) -> tensor<?x?x?x?xi32>
return
}
// -----
-// CHECK-LABEL: @resize_int_offsetted
-func.func @resize_int_offsetted(%arg0: tensor<1x2x4x1xi32>) {
- // CHECK: -> tensor<1x4x6x1xi32>
- %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [64, 64], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 8 : i32, stride = [64, 128], stride_fp = [0.000000e+00 : f32, 0.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
+// CHECK-LABEL: @resize_int_power_of_two_upscale_offsetted
+func.func @resize_int_power_of_two_upscale_offsetted(%arg0: tensor<1x41x26x1xi16>) {
+ // CHECK: -> tensor<1x328x208x1xi48>
+ %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 2, 16, 2], offset = [-7, -7], border = [7, 7]} : (tensor<1x41x26x1xi16>) -> tensor<?x?x?x?xi48>
return
}
// -----
-
// CHECK-LABEL: @resize_fp_horizontal
-func.func @resize_fp_horizontal(%arg0: tensor<1x2x4x1xi32>) {
- // CHECK: -> tensor<1x2x7x1xi32>
- %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [1.000000e+00 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
+func.func @resize_fp_horizontal(%arg0: tensor<1x50x48x1xf32>) {
+ // CHECK: -> tensor<1x106x85x1xf32>
+ %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [15, 7, 84, 47], offset = [0, 0], border = [0, 0]} : (tensor<1x50x48x1xf32>) -> tensor<?x?x?x?xf32>
return
}
// -----
-
// CHECK-LABEL: @resize_fp_vertical
-func.func @resize_fp_vertical(%arg0: tensor<1x2x4x1xi32>) {
- // CHECK: -> tensor<1x3x4x1xi32>
- %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [5.000000e-01 : f32, 1.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
+func.func @resize_fp_vertical(%arg0: tensor<1x50x48x1xf32>) {
+ // CHECK: -> tensor<1x128x13x1xf32>
+ %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [127, 49, 12, 47], offset = [0, 0], border = [0, 0]} : (tensor<1x50x48x1xf32>) -> tensor<?x?x?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @resize_fp_power_of_two_upscale
+func.func @resize_fp_power_of_two_upscale(%arg0: tensor<1x23x23x1xf32>) {
+ // CHECK: -> tensor<1x89x89x1xf32>
+ %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x23x1xf32>) -> tensor<?x?x?x?xf32>
return
}
-// CHECK-LABEL: @resize_fp_offsetted
-func.func @resize_fp_offsetted(%arg0: tensor<1x2x4x1xi32>) {
- // CHECK: -> tensor<1x4x6x1xi32>
- %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [2.500000e-01 : f32, 2.500000e-01 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [2.500000e-01 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
+
+// -----
+
+// CHECK-LABEL: @resize_fp_power_of_two_upscale_offsetted
+func.func @resize_fp_power_of_two_upscale_offsetted(%arg0: tensor<1x50x48x1xf32>) {
+ // CHECK: -> tensor<1x1600x1536x1xf32>
+ %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [64, 2, 64, 2], offset = [-31, -31], border = [31, 31]} : (tensor<1x50x48x1xf32>) -> tensor<?x?x?x?xf32>
return
}
More information about the Mlir-commits
mailing list