[Mlir-commits] [mlir] 4309bb2 - [mlir][tosa] Add broadcasting case for tosa.resize degenerate case

Rob Suderman llvmlistbot at llvm.org
Thu Oct 20 15:45:30 PDT 2022


Author: Rob Suderman
Date: 2022-10-20T15:37:19-07:00
New Revision: 4309bb28ae77061d528b09dfe1c41335e92bc7a3

URL: https://github.com/llvm/llvm-project/commit/4309bb28ae77061d528b09dfe1c41335e92bc7a3
DIFF: https://github.com/llvm/llvm-project/commit/4309bb28ae77061d528b09dfe1c41335e92bc7a3.diff

LOG: [mlir][tosa] Add broadcasting case for tosa.resize degenerate case

When the resize is ?x1x1x?, the tosa.resize operation broadcasts the
input and (when quantized) applies a scaling factor. Updated the resize
lowering to avoid a tensor.extract operation in this case, broadcasting
the single spatial value instead.

Moved the tosa.resize tests to their own MLIR test file due to increased
complexity. Also corrected a bug where floating-point bilinear tosa.resize
was not applying the correct scaling.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D136299

Added: 
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 381c0a15da283..178b4b1f959f1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -1321,7 +1322,104 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
   }
 };
 
-class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
+// Handle the case where the resize operation is a regular broadcast. We
+// perform this part separately to avoid generating Extract operations which
+// are difficult to vectorize / optimize.
+class BroadcastResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
+public:
+  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ResizeOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder builder(loc, rewriter);
+    auto input = op.getInput();
+    auto inputTy = input.getType().cast<RankedTensorType>();
+    auto resultTy = op.getType().cast<RankedTensorType>();
+
+    auto imageH = inputTy.getDimSize(1);
+    auto imageW = inputTy.getDimSize(2);
+    // Only a statically 1x1 spatial input (H = W = 1) is a pure broadcast.
+    if (imageH != 1 || imageW != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "tosa.resize is not a pure broadcast operation");
+    }
+
+    // TODO(suderman): Declare these mode strings in the TOSA dialect.
+    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+      return failure();
+
+    const bool isBilinear = op.getMode() == "BILINEAR";
+    // scale is [y_n, y_d, x_n, x_d]; only y_n / x_n are read below.
+    SmallVector<int32_t> scale;
+    getValuesFromIntArrayAttribute(op.getScale(), scale);
+
+    // Collapse NxHxWxC (H = W = 1) to NxC via [[0], [1, 2, 3]].
+    SmallVector<ReassociationExprs, 4> collapseMap(2);
+    collapseMap[0].push_back(builder.getAffineDimExpr(0));
+    collapseMap[1].push_back(builder.getAffineDimExpr(1));
+    collapseMap[1].push_back(builder.getAffineDimExpr(2));
+    collapseMap[1].push_back(builder.getAffineDimExpr(3));
+
+    auto collapseTy =
+        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
+                              inputTy.getElementType());
+    Value collapse =
+        builder.create<tensor::CollapseShapeOp>(collapseTy, input, collapseMap);
+
+    // Dynamic N/C extents come from `input`. NOTE(review): verify vs resultTy.
+    llvm::SmallVector<Value> outputDynSize;
+    if (inputTy.isDynamicDim(0))
+      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
+
+    if (inputTy.isDynamicDim(3))
+      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
+    // Broadcast: each output (n, h, w, c) reads collapsed element (n, c).
+    llvm::SmallVector<AffineExpr> inputExprs{
+        rewriter.getAffineDimExpr(0),
+        rewriter.getAffineDimExpr(3),
+    };
+
+    auto inputMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
+                                   inputExprs, builder.getContext());
+    auto resultMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+    SmallVector<StringRef> iterators(4, getParallelIteratorTypeName());
+
+    Value empty = builder.create<tensor::EmptyOp>(
+        resultTy.getShape(), resultTy.getElementType(), outputDynSize);
+
+    auto generic = builder.create<linalg::GenericOp>(
+        resultTy, ValueRange{collapse}, ValueRange{empty},
+        ArrayRef<AffineMap>{inputMap, resultMap}, iterators,
+        [=](OpBuilder &b, Location loc, ValueRange args) {
+          Value value = args[0];
+          // Quantized case (element types differ): sign-extend to result type.
+          if (inputTy.getElementType() != resultTy.getElementType()) {
+            value =
+                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
+            // Bilinear also multiplies by the non-zero y_n, then x_n.
+            if (isBilinear && scale[0] != 0) {
+              Value scaleY = b.create<arith::ConstantOp>(
+                  loc, b.getI32IntegerAttr(scale[0]));
+              value = b.create<arith::MulIOp>(loc, value, scaleY);
+            }
+
+            if (isBilinear && scale[2] != 0) {
+              Value scaleX = b.create<arith::ConstantOp>(
+                  loc, b.getI32IntegerAttr(scale[2]));
+              value = b.create<arith::MulIOp>(loc, value, scaleX);
+            }
+          }
+
+          b.create<linalg::YieldOp>(loc, value);
+        });
+
+    rewriter.replaceOp(op, generic.getResult(0));
+    return success();
+  }
+};
+
+class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
 public:
   using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
 
@@ -1351,10 +1449,11 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
     SmallVector<AffineMap, 2> affineMaps = {
         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
 
+    Value resize = input;
     auto genericOp = rewriter.create<linalg::GenericOp>(
         loc, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
         getNParallelLoopsAttrs(resultTy.getRank()));
-    rewriter.replaceOp(op, genericOp.getResult(0));
+    resize = genericOp.getResult(0);
 
     OpBuilder::InsertionGuard regionGuard(rewriter);
     rewriter.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
@@ -1496,7 +1595,6 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
       ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset);
 
       // Clamp the to be within the bounds of the input image.
-
       iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter);
       ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter);
 
@@ -1510,10 +1608,9 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
           loc, input, ValueRange{batch, iy, ix, channel});
 
       rewriter.create<linalg::YieldOp>(loc, result);
-
-      return success();
     } else {
-      // The mode here must be BILINEAR. This has been checked above.
+      // The mode here must be BILINEAR.
+      assert(op.getMode() == "BILINEAR");
       Value y0 = iy;
       Value x0 = ix;
 
@@ -1548,7 +1645,9 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
 
       if (floatingPointMode) {
         Value rightPart = dx;
-        Value leftPart = rewriter.create<arith::SubFOp>(loc, xScaleN, dx);
+        auto oneVal = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getF32FloatAttr(1.0f));
+        Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);
 
         y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
         y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
@@ -1559,46 +1658,59 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
         Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);
 
         Value bottomPart = dy;
-        Value topPart = rewriter.create<arith::SubFOp>(loc, yScaleN, dy);
+        Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
         topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
         bottomAcc = rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
         Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);
 
         rewriter.create<linalg::YieldOp>(loc, result);
-        return success();
-      }
-      y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
-      y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
-      y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
-      y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
-
-      if (resultElementTy.getIntOrFloatBitWidth() > 32) {
-        dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
-        dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
-      }
-
-        Value rightPart = dx;
-        Value leftPart = rewriter.create<arith::SubIOp>(loc, xScaleN, dx);
-
-        y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
-        y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
-        Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
+      } else {
+        // Perform in quantized space.
+        y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
+        y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
+        y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
+        y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
+
+        if (resultElementTy.getIntOrFloatBitWidth() > 32) {
+          dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
+          dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
+        }
 
-        y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
-        y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
-        Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
+        Value topAcc, bottomAcc;
+        if (imageW == 1) {
+          topAcc = rewriter.create<arith::MulIOp>(loc, y0x0, xScaleN);
+          bottomAcc = rewriter.create<arith::MulIOp>(loc, y1x0, xScaleN);
+        } else {
+          Value rightPart = dx;
+          Value leftPart = rewriter.create<arith::SubIOp>(loc, xScaleN, dx);
+
+          y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
+          y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
+          topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
+
+          y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
+          y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
+          bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
+        }
 
-        Value bottomPart = dy;
-        Value topPart = rewriter.create<arith::SubIOp>(loc, yScaleN, dy);
-        topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
-        bottomAcc = rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
-        Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
+        Value result;
+        if (imageH == 1) {
+          result = rewriter.create<arith::MulIOp>(loc, topAcc, yScaleN);
+        } else {
+          Value bottomPart = dy;
+          Value topPart = rewriter.create<arith::SubIOp>(loc, yScaleN, dy);
+          topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
+          bottomAcc =
+              rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
+          result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
+        }
 
         rewriter.create<linalg::YieldOp>(loc, result);
-        return success();
+      }
     }
 
-    return failure();
+    rewriter.replaceOp(op, resize);
+    return success();
   }
 };
 
@@ -2210,6 +2322,13 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
 
 void mlir::tosa::populateTosaToLinalgConversionPatterns(
     RewritePatternSet *patterns) {
+
+  // We have multiple resize converters to handle degenerate cases.
+  patterns->add<GenericResizeConverter>(patterns->getContext(),
+                                        /*benefit=*/100);
+  patterns->add<BroadcastResizeConverter>(patterns->getContext(),
+                                          /*benefit=*/200);
+
   patterns->add<
       // clang-format off
       PointwiseConverter<tosa::AddOp>,
@@ -2262,7 +2381,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       ReshapeConverterExpand,
       ReshapeConverterCollapseExpand,
       RescaleConverter,
-      ResizeConverter,
       ReverseConverter,
       TableConverter,
       TileConverter,

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
new file mode 100644
index 0000000000000..e48be848f6086
--- /dev/null
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -0,0 +1,486 @@
+// RUN: mlir-opt --split-input-file -pass-pipeline="func.func(tosa-to-linalg)" %s -o -| FileCheck %s
+
+// CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: @broadcast_resize_nearest_fp
+func.func @broadcast_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x15x13x7xf32> {
+  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 
+  // CHECK-SAME{literal}: [[0], [1, 2, 3]]
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x15x13x7xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK-SAME: indexing_maps = [#map0, #map1]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xf32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<3x15x13x7xf32>)
+  // CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:   linalg.yield %[[IN]]
+  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xf32>) -> tensor<3x15x13x7xf32>
+  // Pure broadcast: every output element copies input element (n, c).
+  // CHECK: return %[[GENERIC]]
+  return %resize : tensor<3x15x13x7xf32>
+}
+
+// -----
+
+// CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: @broadcast_resize_bilinear_fp
+func.func @broadcast_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x15x13x7xf32> {
+  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
+  // CHECK-SAME{literal}: [[0], [1, 2, 3]]
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x15x13x7xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK-SAME: indexing_maps = [#map0, #map1]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xf32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<3x15x13x7xf32>)
+  // CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:   linalg.yield %[[IN]]
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xf32>) -> tensor<3x15x13x7xf32>
+  // Float bilinear on a 1x1 input is also a plain copy (no scaling).
+  // CHECK: return %[[GENERIC]]
+  return %resize : tensor<3x15x13x7xf32>
+}
+
+// -----
+
+// CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: @broadcast_resize_nearest_i8
+func.func @broadcast_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi8> {
+  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
+  // CHECK-SAME{literal}: [[0], [1, 2, 3]]
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x15x13x7xi8>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK-SAME: indexing_maps = [#map0, #map1]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xi8>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<3x15x13x7xi8>)
+  // CHECK-NEXT: ^bb0(%[[IN:.+]]: i8, %[[OUT:.+]]: i8):
+  // CHECK:   linalg.yield %[[IN]]
+  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi8>
+  // Element types match, so the value is yielded without extension.
+  // CHECK: return %[[GENERIC]]
+  return %resize : tensor<3x15x13x7xi8>
+}
+
+// -----
+
+// CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: @broadcast_resize_nearest_i32
+func.func @broadcast_resize_nearest_i32(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi32> {
+  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 
+  // CHECK-SAME{literal}: [[0], [1, 2, 3]]
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x15x13x7xi32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK-SAME: indexing_maps = [#map0, #map1]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xi8>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<3x15x13x7xi32>)
+  // CHECK-NEXT: ^bb0(%[[IN:.+]]: i8, %[[OUT:.+]]: i32):
+  // CHECK:   %[[EXT:.+]] = arith.extsi %[[IN]] : i8 to i32 
+  // CHECK:   linalg.yield %[[EXT]]
+  %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi32>
+  // i8 -> i32: the value is only sign-extended; nearest applies no scale.
+  // CHECK: return %[[GENERIC]]
+  return %resize : tensor<3x15x13x7xi32>
+}
+
+// -----
+
+// CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: @broadcast_resize_bilinear_i32
+func.func @broadcast_resize_bilinear_i32(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi32> {
+  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
+  // CHECK-SAME{literal}: [[0], [1, 2, 3]]
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x15x13x7xi32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic 
+  // CHECK-SAME: indexing_maps = [#map0, #map1]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xi8>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<3x15x13x7xi32>)
+  // CHECK-NEXT: ^bb0(%[[IN:.+]]: i8, %[[OUT:.+]]: i32):
+  // CHECK: %[[EXT:.+]] = arith.extsi %[[IN]] : i8 to i32 
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
+  // CHECK: %[[MUL1:.+]] = arith.muli %[[EXT]], %[[C2]] : i32
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
+  // CHECK: %[[MUL2:.+]] = arith.muli %[[MUL1]], %[[C1]] : i32
+  // CHECK: linalg.yield %[[MUL2]]
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 2, 1, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x1x7xi8>) -> tensor<3x15x13x7xi32>
+  // Quantized bilinear scales by y_n (2) then x_n (1) from `scale`.
+  // CHECK: return %[[GENERIC]]
+  return %resize : tensor<3x15x13x7xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  @resize_nearest_int
+func.func @resize_nearest_int(%arg0: tensor<1x15x13x1xi8>) -> () {
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x23x179x1xi8>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK: %[[IDX_0:.+]] = linalg.index 0
+  // CHECK: %[[IDX_1:.+]] = linalg.index 1
+  // CHECK: %[[IDX_2:.+]] = linalg.index 2
+  // CHECK: %[[IDX_3:.+]] = linalg.index 3
+  // CHECK-DAG: %[[XY_MIN:.+]] = arith.constant 0
+  // CHECK-DAG: %[[Y_MAX:.+]] = arith.constant 14
+  // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 12
+  // Y_MAX / X_MAX are the last valid input row / column (15 - 1, 13 - 1).
+  // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
+  // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
+  // CHECK-DAG: %[[SCALE_Y_N:.*]] = arith.constant 11
+  // CHECK-DAG: %[[SCALE_Y_D:.*]] = arith.constant 7
+  // CHECK-DAG: %[[SCALE_X_N:.*]] = arith.constant 89
+  // CHECK-DAG: %[[SCALE_X_D:.*]] = arith.constant 6
+  // CHECK-DAG: %[[OFFSET_Y:.*]] = arith.constant 0
+  // CHECK-DAG: %[[OFFSET_X:.*]] = arith.constant 0
+  // CHECK-DAG: %[[BORDER_Y:.*]] = arith.constant 0
+  // CHECK-DAG: %[[BORDER_X:.*]] = arith.constant 0
+  // scale = [y_n, y_d, x_n, x_d] = [11, 7, 89, 6] from the op below.
+  // Find the remainder and integer component of the target index.
+
+  // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]]
+  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
+  // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]]
+  // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
+  // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]]
+  // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]]
+  // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]]
+  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
+  // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]]
+  // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]]
+
+  // Round to the nearest neighbor.
+
+  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0
+  // CHECK-DAG: %[[ONE:.*]] = arith.constant 1
+  // CHECK: %[[SCALE_Y_N_HALF:.*]] = arith.shrsi %[[SCALE_Y_N]], %[[ONE]]
+  // CHECK: %[[SCALE_X_N_HALF:.*]] = arith.shrsi %[[SCALE_X_N]], %[[ONE]]
+  // CHECK: %[[PRED_Y:.*]] = arith.cmpi sge, %[[D_Y]], %[[SCALE_Y_N_HALF]]
+  // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X]], %[[SCALE_X_N_HALF]]
+  // CHECK: %[[VAL_37:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[VAL_39:.*]] = arith.addi %[[I_Y]], %[[VAL_37]]
+  // CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]]
+
+  // Clamp both indices to the bounds of the input image.
+
+  // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[XY_MIN]]
+  // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[XY_MIN]], %[[VAL_39]]
+  // CHECK: %[[VAL_43:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[VAL_39]]
+  // CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[Y_MAX]], %[[VAL_42]]
+  // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[XY_MIN]]
+  // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[XY_MIN]], %[[VAL_40]]
+  // CHECK: %[[VAL_47:.*]] = arith.cmpi slt, %[[X_MAX]], %[[VAL_40]]
+  // CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_47]], %[[X_MAX]], %[[VAL_46]]
+
+  // Extract the nearest value using the computed indices.
+
+  // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]]
+  // CHECK: %[[IDX:.+]] = arith.index_cast %[[VAL_48]]
+  // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[IDY]], %[[IDX]], %[[IDX_3]]]
+  // CHECK: linalg.yield %[[EXTRACT]]
+
+  // Nearest-neighbor resize under test (its result is unused).
+  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor<1x23x179x1xi8>
+           return
+}
+
+// -----
+
+// CHECK-LABEL:  @resize_bilinear_int
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @resize_bilinear_int(%arg0: tensor<1x19x19x1xi8>) {
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x289x289x1xi32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK: %[[IDX_0:.+]] = linalg.index 0
+  // CHECK: %[[IDX_1:.+]] = linalg.index 1
+  // CHECK: %[[IDX_2:.+]] = linalg.index 2
+  // CHECK: %[[IDX_3:.+]] = linalg.index 3
+  // CHECK-DAG: %[[XY_MIN:.+]] = arith.constant 0
+  // CHECK-DAG: %[[Y_MAX:.+]] = arith.constant 18
+  // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 18
+  // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
+  // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
+  // CHECK-DAG: %[[SCALE_Y_N:.*]] = arith.constant 16
+  // CHECK-DAG: %[[SCALE_Y_D:.*]] = arith.constant 1
+  // CHECK-DAG: %[[SCALE_X_N:.*]] = arith.constant 16
+  // CHECK-DAG: %[[SCALE_X_D:.*]] = arith.constant 1
+  // CHECK-DAG: %[[OFFSET_Y:.*]] = arith.constant 0
+  // CHECK-DAG: %[[OFFSET_X:.*]] = arith.constant 0
+  // CHECK-DAG: %[[BORDER_Y:.*]] = arith.constant 0
+  // CHECK-DAG: %[[BORDER_X:.*]] = arith.constant 0
+  // scale = [y_n, y_d, x_n, x_d] = [16, 1, 16, 1] from the op below.
+  // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]]
+  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
+  // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]]
+  // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
+  // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]]
+  // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]]
+  // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]]
+  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
+  // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]]
+  // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]]
+
+  // Compute the next row and column indices (y + 1, x + 1) to interpolate.
+
+  // CHECK-DAG: %[[ONE:.*]] = arith.constant 1
+  // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]]
+  // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
+
+  // Clamp each of the four indices to [0, 18].
+
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]]
+  // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]]
+  // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]]
+  // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]]
+  // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
+
+  // Extract each corner of the bilinear interpolation.
+
+  // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
+  // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
+  // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]]
+  // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]]
+
+  // CHECK: %[[LOLO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]]
+  // CHECK: %[[LOHI:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]]
+  // CHECK: %[[HILO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]]
+  // CHECK: %[[HIHI:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YHII]], %[[XHII]], %[[IDX_3]]]
+
+  // CHECK: %[[XLOLO:.+]] = arith.extsi %[[LOLO]]
+  // CHECK: %[[XLOHI:.+]] = arith.extsi %[[LOHI]]
+  // CHECK: %[[XHILO:.+]] = arith.extsi %[[HILO]]
+  // CHECK: %[[XHIHI:.+]] = arith.extsi %[[HIHI]]
+
+  // Compute the bilinear interpolation.
+
+  // CHECK: %[[NDX:.+]] = arith.subi %[[SCALE_X_N]], %[[D_X]]
+  // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]]
+  // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X]]
+  // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]]
+  // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]]
+  // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X]]
+  // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]]
+  // CHECK: %[[NDY:.+]] = arith.subi %[[SCALE_Y_N]], %[[D_Y]]
+  // CHECK: %[[WLO:.+]] = arith.muli %[[LO]], %[[NDY]]
+  // CHECK: %[[WHI:.+]] = arith.muli %[[HI]], %[[D_Y]]
+  // CHECK: %[[RESULT:.+]] = arith.addi %[[WLO]], %[[WHI]]
+  // CHECK: linalg.yield %[[RESULT]]
+
+  // Quantized bilinear resize under test (its result is unused).
+  %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x19x1xi8>) -> tensor<1x289x289x1xi32>
+           return
+}
+
+// -----
+
+// CHECK-LABEL: @resize_nearest_fp
+func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x1600x1536x1xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK: %[[IDX0:.+]] = linalg.index 0
+  // CHECK: %[[IDX1:.+]] = linalg.index 1
+  // CHECK: %[[IDX2:.+]] = linalg.index 2
+  // CHECK: %[[IDX3:.+]] = linalg.index 3
+  // CHECK-DAG: %[[XYMIN:.*]] = arith.constant 0
+  // CHECK-DAG: %[[YMAX:.*]] = arith.constant 49
+  // CHECK-DAG: %[[XMAX:.*]] = arith.constant 47
+  // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX1]]
+  // CHECK: %[[X:.+]] = arith.index_cast %[[IDX2]]
+  // CHECK-DAG: %[[ISCALE_Y_N:.*]] = arith.constant 64
+  // CHECK-DAG: %[[ISCALE_Y_D:.*]] = arith.constant 2
+  // CHECK-DAG: %[[ISCALE_X_N:.*]] = arith.constant 64
+  // CHECK-DAG: %[[ISCALE_X_D:.*]] = arith.constant 2
+  // CHECK-DAG: %[[IOFFSET_Y:.*]] = arith.constant -31
+  // CHECK-DAG: %[[IOFFSET_X:.*]] = arith.constant -31
+  // CHECK-DAG: %[[IBORDER_Y:.*]] = arith.constant 31
+  // CHECK-DAG: %[[IBORDER_X:.*]] = arith.constant 31
+
+  // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]]
+  // CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
+  // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]]
+  // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]]
+  // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]]
+  // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]]
+  // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]]
+  // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]]
+  // NOTE(review): uitofp on the negative (-31) offsets looks wrong; sitofp?
+  // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
+  // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]]
+  // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
+  // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]]
+  // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
+  // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]]
+
+  // Find the remainder and integer component of the target index.
+
+  // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
+  // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]]
+  // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
+  // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]]
+  // CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]]
+  // CHECK: %[[VAL_40:.*]] = arith.fptosi %[[VAL_36]]
+  // Round to the nearest neighbor: add 1 when the fraction is >= 0.5.
+  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0
+  // CHECK-DAG: %[[ONE:.*]] = arith.constant 1
+  // CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01
+  // CHECK: %[[PRED_Y:.*]] = arith.cmpf oge, %[[D_Y]], %[[HALF]]
+  // CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]]
+  // CHECK: %[[ROUND_Y:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_39]], %[[ROUND_Y]]
+  // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]]
+  // Clamp to [0, 49] x [0, 47], the last valid input row / column.
+  // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[XYMIN]]
+  // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[XYMIN]], %[[VAL_48]]
+  // CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[YMAX]], %[[VAL_48]]
+  // CHECK: %[[VAL_53:.*]] = arith.select %[[VAL_52]], %[[YMAX]], %[[VAL_51]]
+  // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[XYMIN]]
+  // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[XYMIN]], %[[VAL_49]]
+  // CHECK: %[[VAL_56:.*]] = arith.cmpi slt, %[[XMAX]], %[[VAL_49]]
+  // CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[XMAX]], %[[VAL_55]]
+
+  // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]]
+  // CHECK: %[[IDX:.*]] = arith.index_cast %[[VAL_57]]
+  // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]]
+  // CHECK: linalg.yield %[[EXTRACT]]
+  // Nearest-neighbor resize with negative offsets (its result is unused).
+  %output = "tosa.resize"(%input) {mode = "NEAREST_NEIGHBOR", scale = [64, 2, 64, 2], offset = [-31, -31], border = [31, 31]} : (tensor<1x50x48x1xf32>) -> tensor<1x1600x1536x1xf32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @resize_bilinear_fp
+func.func @resize_bilinear_fp(%input: tensor<1x23x23x1xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x89x89x1xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK: %[[IDX_0:.+]] = linalg.index 0
+  // CHECK: %[[IDX_1:.+]] = linalg.index 1
+  // CHECK: %[[IDX_2:.+]] = linalg.index 2
+  // CHECK: %[[IDX_3:.+]] = linalg.index 3
+  // CHECK-DAG: %[[XY_MIN:.*]] = arith.constant 0
+  // CHECK-DAG: %[[Y_MAX:.*]] = arith.constant 22
+  // CHECK-DAG: %[[X_MAX:.*]] = arith.constant 22
+  // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
+  // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
+  // CHECK-DAG: %[[ISCALE_Y_N:.*]] = arith.constant 4
+  // CHECK-DAG: %[[ISCALE_Y_D:.*]] = arith.constant 1
+  // CHECK-DAG: %[[ISCALE_X_N:.*]] = arith.constant 4
+  // CHECK-DAG: %[[ISCALE_X_D:.*]] = arith.constant 1
+  // CHECK-DAG: %[[IOFFSET_Y:.*]] = arith.constant 0
+  // CHECK-DAG: %[[IOFFSET_X:.*]] = arith.constant 0
+  // CHECK-DAG: %[[IBORDER_Y:.*]] = arith.constant 0
+  // CHECK-DAG: %[[IBORDER_X:.*]] = arith.constant 0
+  // Convert coordinates and scale/offset to float. NOTE(review): uitofp is unsigned, so a negative offset would be misconverted; offsets are 0 here.
+  // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]]
+  // CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
+  // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]]
+  // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]]
+  // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]]
+  // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]]
+  // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]]
+  // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]]
+  // Source coordinate = (output coordinate * scale_d + offset) / scale_n, per axis.
+  // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
+  // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]]
+  // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
+  // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]]
+  // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
+  // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]]
+  // Split the source coordinate into an integer index (floor) and a fractional weight (D_Y / D_X).
+  // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
+  // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]]
+  // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
+  // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]]
+  // CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]]
+  // CHECK: %[[I_X:.*]] = arith.fptosi %[[VAL_36]]
+
+  // Compute the next-row (y + 1) and next-column (x + 1) indices for the bilinear interpolation.
+
+  // CHECK-DAG: %[[ONE:.*]] = arith.constant 1
+  // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]]
+  // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
+
+  // Clamp each index to the input bounds [XY_MIN, Y_MAX] / [XY_MIN, X_MAX], i.e. [0, 22] here.
+
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]]
+  // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]]
+  // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]]
+  // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]]
+  // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
+  // Cast the clamped indices to index type and extract the four corner values.
+  // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
+  // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
+  // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]]
+  // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]]
+
+  // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]]
+  // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]]
+  // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]]
+  // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YHII]], %[[XHII]], %[[IDX_3]]]
+  // Interpolate along x with weights (1 - D_X, D_X), then along y with weights (1 - D_Y, D_Y).
+  // CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 : f32
+  // CHECK: %[[NDX:.+]] = arith.subf %[[ONE]], %[[D_X]]
+  // CHECK: %[[WLOLO:.+]] = arith.mulf %[[LOLO]], %[[NDX]]
+  // CHECK: %[[WLOHI:.+]] = arith.mulf %[[LOHI]], %[[D_X]]
+  // CHECK: %[[LO:.+]] = arith.addf %[[WLOLO]], %[[WLOHI]]
+  // CHECK: %[[WHILO:.+]] = arith.mulf %[[HILO]], %[[NDX]]
+  // CHECK: %[[WHIHI:.+]] = arith.mulf %[[HIHI]], %[[D_X]]
+  // CHECK: %[[HI:.+]] = arith.addf %[[WHILO]], %[[WHIHI]]
+  // CHECK: %[[NDY:.+]] = arith.subf %[[ONE]], %[[D_Y]]
+  // CHECK: %[[WLO:.+]] = arith.mulf %[[LO]], %[[NDY]]
+  // CHECK: %[[WHI:.+]] = arith.mulf %[[HI]], %[[D_Y]]
+  // CHECK: %[[RESULT:.+]] = arith.addf %[[WLO]], %[[WHI]]
+  // CHECK: linalg.yield %[[RESULT]]
+
+  // Resize with bilinear interpolation: scale 4/1 on both axes, zero offset and border.
+  %output = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x23x1xf32>) -> tensor<1x89x89x1xf32>
+
+  return
+}
+
+// -----
+// Dynamic batch: the output init tensor is sized from tensor.dim of the input's dimension 0.
+// CHECK-LABEL: @resize_dyn
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @resize_dyn(%input: tensor<?x2x2x1xi8>) -> () {
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x4x4x1xi32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  %output = "tosa.resize"(%input) { scale = [4, 2, 4, 2], offset = [-1, -1], border = [1, 1], mode = "BILINEAR" } : (tensor<?x2x2x1xi8>)  -> (tensor<?x4x4x1xi32>)
+  return
+}

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f94104ac2a0d8..7c7a8aab3bb47 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1643,376 +1643,6 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 
 // -----
 
-// CHECK-LABEL:  @resize_nearest_int
-func.func @resize_nearest_int(%arg0: tensor<1x15x13x1xi8>) -> () {
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x23x179x1xi8>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK: %[[IDX_0:.+]] = linalg.index 0
-  // CHECK: %[[IDX_1:.+]] = linalg.index 1
-  // CHECK: %[[IDX_2:.+]] = linalg.index 2
-  // CHECK: %[[IDX_3:.+]] = linalg.index 3
-  // CHECK: %[[XY_MIN:.+]] = arith.constant 0
-  // CHECK: %[[Y_MAX:.+]] = arith.constant 14
-  // CHECK: %[[X_MAX:.+]] = arith.constant 12
-
-  // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
-  // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
-  // CHECK: %[[SCALE_Y_N:.*]] = arith.constant 11
-  // CHECK: %[[SCALE_Y_D:.*]] = arith.constant 7
-  // CHECK: %[[SCALE_X_N:.*]] = arith.constant 89
-  // CHECK: %[[SCALE_X_D:.*]] = arith.constant 6
-  // CHECK: %[[OFFSET_Y:.*]] = arith.constant 0
-  // CHECK: %[[OFFSET_X:.*]] = arith.constant 0
-  // CHECK: %[[BORDER_Y:.*]] = arith.constant 0
-  // CHECK: %[[BORDER_X:.*]] = arith.constant 0
-
-  // find the remainder and integer component of the target index.
-
-  // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]]
-  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
-  // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]]
-  // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
-  // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]]
-  // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]]
-  // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]]
-  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
-  // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]]
-  // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]]
-
-  // Round to the nearest neighor.
-
-  // CHECK: %[[ZERO:.*]] = arith.constant 0
-  // CHECK: %[[ONE:.*]] = arith.constant 1
-  // CHECK: %[[SCALE_Y_N_HALF:.*]] = arith.shrsi %[[SCALE_Y_N]], %[[ONE]]
-  // CHECK: %[[SCALE_X_N_HALF:.*]] = arith.shrsi %[[SCALE_X_N]], %[[ONE]]
-  // CHECK: %[[PRED_Y:.*]] = arith.cmpi sge, %[[D_Y]], %[[SCALE_Y_N_HALF]]
-  // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X]], %[[SCALE_X_N_HALF]]
-  // CHECK: %[[VAL_37:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
-  // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
-  // CHECK: %[[VAL_39:.*]] = arith.addi %[[I_Y]], %[[VAL_37]]
-  // CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]]
-
-  // This section applies bound checking to be within the input image.
-
-  // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[XY_MIN]]
-  // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[XY_MIN]], %[[VAL_39]]
-  // CHECK: %[[VAL_43:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[VAL_39]]
-  // CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[Y_MAX]], %[[VAL_42]]
-  // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[XY_MIN]]
-  // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[XY_MIN]], %[[VAL_40]]
-  // CHECK: %[[VAL_47:.*]] = arith.cmpi slt, %[[X_MAX]], %[[VAL_40]]
-  // CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_47]], %[[X_MAX]], %[[VAL_46]]
-
-  // Extract the nearest value using the computed indices.
-
-  // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]]
-  // CHECK: %[[IDX:.+]] = arith.index_cast %[[VAL_48]]
-  // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[IDY]], %[[IDX]], %[[IDX_3]]]
-  // CHECK: linalg.yield %[[EXTRACT]]
-
-  // Round to the nearest index.
-  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor<1x23x179x1xi8>
-           return
-}
-
-// -----
-
-// CHECK-LABEL:  @resize_bilinear_int
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @resize_bilinear_int(%arg0: tensor<1x19x19x1xi8>) {
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x289x289x1xi32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK: %[[IDX_0:.+]] = linalg.index 0
-  // CHECK: %[[IDX_1:.+]] = linalg.index 1
-  // CHECK: %[[IDX_2:.+]] = linalg.index 2
-  // CHECK: %[[IDX_3:.+]] = linalg.index 3
-  // CHECK: %[[XY_MIN:.+]] = arith.constant 0
-  // CHECK: %[[Y_MAX:.+]] = arith.constant 18
-  // CHECK: %[[X_MAX:.+]] = arith.constant 18
-  // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
-  // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
-  // CHECK: %[[SCALE_Y_N:.*]] = arith.constant 16
-  // CHECK: %[[SCALE_Y_D:.*]] = arith.constant 1
-  // CHECK: %[[SCALE_X_N:.*]] = arith.constant 16
-  // CHECK: %[[SCALE_X_D:.*]] = arith.constant 1
-  // CHECK: %[[OFFSET_Y:.*]] = arith.constant 0
-  // CHECK: %[[OFFSET_X:.*]] = arith.constant 0
-  // CHECK: %[[BORDER_Y:.*]] = arith.constant 0
-  // CHECK: %[[BORDER_X:.*]] = arith.constant 0
-
-  // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]]
-  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
-  // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]]
-  // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
-  // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]]
-  // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]]
-  // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]]
-  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
-  // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]]
-  // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]]
-
-  // Compute the left, right, and top indices for the bilinear interpolation.
-
-  // CHECK: %[[ONE:.*]] = arith.constant 1
-  // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]]
-  // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
-
-  // Bound check each dimension.
-
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]]
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]]
-  // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
-
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]]
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]]
-  // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
-
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]]
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]]
-  // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
-
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]]
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]]
-  // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
-
-  // Extract each corner of the bilinear interpolation.
-
-  // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
-  // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
-  // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]]
-  // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]]
-
-  // CHECK: %[[LOLO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]]
-  // CHECK: %[[LOHI:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]]
-  // CHECK: %[[HILO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]]
-  // CHECK: %[[HIHI:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YHII]], %[[XHII]], %[[IDX_3]]]
-
-  // CHECK: %[[XLOLO:.+]] = arith.extsi %[[LOLO]]
-  // CHECK: %[[XLOHI:.+]] = arith.extsi %[[LOHI]]
-  // CHECK: %[[XHILO:.+]] = arith.extsi %[[HILO]]
-  // CHECK: %[[XHIHI:.+]] = arith.extsi %[[HIHI]]
-
-  // Compute the bilinear interpolation.
-
-  // CHECK: %[[NDX:.+]] = arith.subi %[[SCALE_X_N]], %[[D_X]]
-  // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]]
-  // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X]]
-  // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]]
-  // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]]
-  // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X]]
-  // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]]
-  // CHECK: %[[NDY:.+]] = arith.subi %[[SCALE_Y_N]], %[[D_Y]]
-  // CHECK: %[[WLO:.+]] = arith.muli %[[LO]], %[[NDY]]
-  // CHECK: %[[WHI:.+]] = arith.muli %[[HI]], %[[D_Y]]
-  // CHECK: %[[RESULT:.+]] = arith.addi %[[WLO]], %[[WHI]]
-  // CHECK: linalg.yield %[[RESULT]]
-
-  // Round to the nearest index.
-  %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x19x1xi8>) -> tensor<1x289x289x1xi32>
-           return
-}
-
-// -----
-
-// CHECK-LABEL: @resize_nearest_fp
-func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () {
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x1600x1536x1xf32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK: %[[IDX0:.+]] = linalg.index 0
-  // CHECK: %[[IDX1:.+]] = linalg.index 1
-  // CHECK: %[[IDX2:.+]] = linalg.index 2
-  // CHECK: %[[IDX3:.+]] = linalg.index 3
-  // CHECK: %[[XYMIN:.*]] = arith.constant 0
-  // CHECK: %[[YMAX:.*]] = arith.constant 49
-  // CHECK: %[[XMAX:.*]] = arith.constant 47
-  // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX1]]
-  // CHECK: %[[X:.+]] = arith.index_cast %[[IDX2]]
-  // CHECK: %[[ISCALE_Y_N:.*]] = arith.constant 64
-  // CHECK: %[[ISCALE_Y_D:.*]] = arith.constant 2
-  // CHECK: %[[ISCALE_X_N:.*]] = arith.constant 64
-  // CHECK: %[[ISCALE_X_D:.*]] = arith.constant 2
-  // CHECK: %[[IOFFSET_Y:.*]] = arith.constant -31
-  // CHECK: %[[IOFFSET_X:.*]] = arith.constant -31
-  // CHECK: %[[IBORDER_Y:.*]] = arith.constant 31
-  // CHECK: %[[IBORDER_X:.*]] = arith.constant 31
-
-  // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]]
-  // CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
-  // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]]
-  // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]]
-  // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]]
-  // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]]
-  // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]]
-  // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]]
-
-  // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
-  // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]]
-  // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
-  // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]]
-  // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
-  // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]]
-
-  // Find the remainder and integer component of the target index.
-
-  // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
-  // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]]
-  // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
-  // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]]
-  // CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]]
-  // CHECK: %[[VAL_40:.*]] = arith.fptosi %[[VAL_36]]
-
-  // CHECK: %[[ZERO:.*]] = arith.constant 0
-  // CHECK: %[[ONE:.*]] = arith.constant 1
-  // CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01
-  // CHECK: %[[PRED_Y:.*]] = arith.cmpf oge, %[[D_Y]], %[[HALF]]
-  // CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]]
-  // CHECK: %[[ROUND_Y:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
-  // CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
-  // CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_39]], %[[ROUND_Y]]
-  // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]]
-
-  // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[XYMIN]]
-  // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[XYMIN]], %[[VAL_48]]
-  // CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[YMAX]], %[[VAL_48]]
-  // CHECK: %[[VAL_53:.*]] = arith.select %[[VAL_52]], %[[YMAX]], %[[VAL_51]]
-  // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[XYMIN]]
-  // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[XYMIN]], %[[VAL_49]]
-  // CHECK: %[[VAL_56:.*]] = arith.cmpi slt, %[[XMAX]], %[[VAL_49]]
-  // CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[XMAX]], %[[VAL_55]]
-
-  // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]]
-  // CHECK: %[[IDX:.*]] = arith.index_cast %[[VAL_57]]
-  // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]]
-  // CHECK: linalg.yield %[[EXTRACT]]
-
-  %output = "tosa.resize"(%input) {mode = "NEAREST_NEIGHBOR", scale = [64, 2, 64, 2], offset = [-31, -31], border = [31, 31]} : (tensor<1x50x48x1xf32>) -> tensor<1x1600x1536x1xf32>
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @resize_bilinear_fp
-func.func @resize_bilinear_fp(%input: tensor<1x23x23x1xf32>) -> () {
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x89x89x1xf32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK: %[[IDX_0:.+]] = linalg.index 0
-  // CHECK: %[[IDX_1:.+]] = linalg.index 1
-  // CHECK: %[[IDX_2:.+]] = linalg.index 2
-  // CHECK: %[[IDX_3:.+]] = linalg.index 3
-  // CHECK: %[[XY_MIN:.*]] = arith.constant 0
-  // CHECK: %[[Y_MAX:.*]] = arith.constant 22
-  // CHECK: %[[X_MAX:.*]] = arith.constant 22
-  // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
-  // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
-  // CHECK: %[[ISCALE_Y_N:.*]] = arith.constant 4
-  // CHECK: %[[ISCALE_Y_D:.*]] = arith.constant 1
-  // CHECK: %[[ISCALE_X_N:.*]] = arith.constant 4
-  // CHECK: %[[ISCALE_X_D:.*]] = arith.constant 1
-  // CHECK: %[[IOFFSET_Y:.*]] = arith.constant 0
-  // CHECK: %[[IOFFSET_X:.*]] = arith.constant 0
-  // CHECK: %[[IBORDER_Y:.*]] = arith.constant 0
-  // CHECK: %[[IBORDER_X:.*]] = arith.constant 0
-
-  // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]]
-  // CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
-  // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]]
-  // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]]
-  // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]]
-  // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]]
-  // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]]
-  // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]]
-
-  // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
-  // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]]
-  // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
-  // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]]
-  // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
-  // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]]
-
-  // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
-  // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]]
-  // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
-  // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]]
-  // CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]]
-  // CHECK: %[[I_X:.*]] = arith.fptosi %[[VAL_36]]
-
-  // Compute the left, right, and top indices for the bilinear interpolation.
-
-  // CHECK: %[[ONE:.*]] = arith.constant 1
-  // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]]
-  // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
-
-  // Bound check each dimension.
-
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]]
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]]
-  // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
-
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]]
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]]
-  // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
-
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]]
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]]
-  // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
-
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]]
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]]
-  // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
-
-  // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
-  // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
-  // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]]
-  // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]]
-
-  // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]]
-  // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]]
-  // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]]
-  // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[YHII]], %[[XHII]], %[[IDX_3]]]
-
-  // CHECK: %[[NDX:.+]] = arith.subf %[[SCALE_X_N]], %[[D_X]]
-  // CHECK: %[[WLOLO:.+]] = arith.mulf %[[LOLO]], %[[NDX]]
-  // CHECK: %[[WLOHI:.+]] = arith.mulf %[[LOHI]], %[[D_X]]
-  // CHECK: %[[LO:.+]] = arith.addf %[[WLOLO]], %[[WLOHI]]
-  // CHECK: %[[WHILO:.+]] = arith.mulf %[[HILO]], %[[NDX]]
-  // CHECK: %[[WHIHI:.+]] = arith.mulf %[[HIHI]], %[[D_X]]
-  // CHECK: %[[HI:.+]] = arith.addf %[[WHILO]], %[[WHIHI]]
-  // CHECK: %[[NDY:.+]] = arith.subf %[[SCALE_Y_N]], %[[D_Y]]
-  // CHECK: %[[WLO:.+]] = arith.mulf %[[LO]], %[[NDY]]
-  // CHECK: %[[WHI:.+]] = arith.mulf %[[HI]], %[[D_Y]]
-  // CHECK: %[[RESULT:.+]] = arith.addf %[[WLO]], %[[WHI]]
-  // CHECK: linalg.yield %[[RESULT]]
-
-  // Round by bilinear interpolation
-  %output = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x23x1xf32>) -> tensor<1x89x89x1xf32>
-
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @resize_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @resize_dyn(%input: tensor<?x2x2x1xi8>) -> () {
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x4x4x1xi32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  %output = "tosa.resize"(%input) { scale = [4, 2, 4, 2], offset = [-1, -1], border = [1, 1], mode = "BILINEAR" } : (tensor<?x2x2x1xi8>)  -> (tensor<?x4x4x1xi32>)
-  return
-}
-
-// -----
-
 // Regression test for using the wrong rank.
 
 // CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>


        


More information about the Mlir-commits mailing list