[Mlir-commits] [mlir] 78503e1 - [mlir][tosa] Refactor tosa.resize

Rob Suderman llvmlistbot at llvm.org
Mon Dec 12 14:39:47 PST 2022


Author: Rob Suderman
Date: 2022-12-12T14:38:38-08:00
New Revision: 78503e1a2f505ab4e7008d7df81eca76a546d203

URL: https://github.com/llvm/llvm-project/commit/78503e1a2f505ab4e7008d7df81eca76a546d203
DIFF: https://github.com/llvm/llvm-project/commit/78503e1a2f505ab4e7008d7df81eca76a546d203.diff

LOG: [mlir][tosa] Refactor tosa.resize

Moved to using helper lambdas to avoid code repetition. The generated IR had to be reordered to
accommodate this; that reordering should be the only change to the existing tests.

This changes the quantized test to target `i48` types to verify that values are extended
correctly when necessary.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D136500

Added: 
    mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir

Removed: 
    mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
similarity index 90%
rename from mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
rename to mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 13ead1cde4d43..5c30ddf921dc7 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -30,13 +30,13 @@ SmallVector<Value> condenseValues(const SmallVector<Value> &values);
 
 // Takes the parameters for a clamp and turns it into a series of ops for float
 // inputs.
-Value clampFloatHelper(Location loc, Value arg, arith::ConstantOp min,
-                       arith::ConstantOp max, OpBuilder &rewriter);
+Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
+                       OpBuilder &rewriter);
 
 // Takes the parameters for a clamp and turns it into a series of ops for
 // integer inputs.
-Value clampIntHelper(Location loc, Value arg, arith::ConstantOp min,
-                     arith::ConstantOp max, OpBuilder &rewriter);
+Value clampIntHelper(Location loc, Value arg, Value min, Value max,
+                     OpBuilder &rewriter);
 
 // Returns the values in an attribute as an array of values.
 template <typename T>

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index abe5f53374813..d704b5e040916 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -18,7 +18,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
@@ -177,10 +177,10 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
 
     // Clamp to the negation range.
-    auto min = rewriter.create<arith::ConstantIntOp>(
+    Value min = rewriter.create<arith::ConstantIntOp>(
         loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
         intermediateType);
-    auto max = rewriter.create<arith::ConstantIntOp>(
+    Value max = rewriter.create<arith::ConstantIntOp>(
         loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
         intermediateType);
     auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
@@ -1431,10 +1431,11 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
   LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
     auto input = op.getInput();
     auto inputTy = input.getType().cast<ShapedType>();
     auto resultTy = op.getType().cast<ShapedType>();
-    auto resultElementTy = resultTy.getElementType();
+    auto resultETy = resultTy.getElementType();
 
     auto imageH = inputTy.getShape()[1];
     auto imageW = inputTy.getShape()[2];
@@ -1444,284 +1445,229 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
     if (!dynamicDimsOr.has_value())
       return rewriter.notifyMatchFailure(
           op, "unable to get dynamic dimensions of tosa.resize");
-    SmallVector<Value> dynamicDims = dynamicDimsOr.value();
 
     if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
       return rewriter.notifyMatchFailure(
           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
 
-    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultElementTy, dynamicDims);
-
     SmallVector<AffineMap, 2> affineMaps = {
         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
-
-    Value resize = input;
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
+    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
+                                                 dynamicDimsOr.value());
+    auto genericOp = b.create<linalg::GenericOp>(
+        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
         getNParallelLoopsAttrs(resultTy.getRank()));
-    resize = genericOp.getResult(0);
-
-    OpBuilder::InsertionGuard regionGuard(rewriter);
-    rewriter.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
-                         TypeRange({resultElementTy}), loc);
-    Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
-    Value y = rewriter.create<linalg::IndexOp>(loc, 1);
-    Value x = rewriter.create<linalg::IndexOp>(loc, 2);
-    Value channel = rewriter.create<linalg::IndexOp>(loc, 3);
-
-    auto hwMin =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
-    auto hMax = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(imageH - 1));
-    auto wMax = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(imageW - 1));
-
-    Value inY =
-        rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), y);
-    Value inX =
-        rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);
-
-    bool floatingPointMode = resultElementTy.isF32();
-
-    Value yScaleN, yScaleD, xScaleN, xScaleD, yOffset, xOffset, yBorder,
-        xBorder;
-    SmallVector<int32_t> scale, offset, border;
-    getValuesFromIntArrayAttribute(op.getScale(), scale);
-    getValuesFromIntArrayAttribute(op.getOffset(), offset);
-    getValuesFromIntArrayAttribute(op.getBorder(), border);
-
-    yScaleN = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(scale[0]));
-    yScaleD = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(scale[1]));
-    xScaleN = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(scale[2]));
-    xScaleD = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(scale[3]));
-    yOffset = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(offset[0]));
-    xOffset = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(offset[1]));
-    yBorder = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(border[0]));
-    xBorder = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(border[1]));
-
-    // Compute the the integer index and partial offset.
-    Value ix, iy, dx, dy;
-    // x = x * scale_d + offset;
-    // ix = floor(x / scale_n)
-    if (floatingPointMode) {
-      // dx = x / scale_n - ix
-      Value y =
-          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY);
-      Value x =
-          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX);
-
-      yScaleN =
-          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), yScaleN);
-      yScaleD =
-          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), yScaleD);
-      xScaleN =
-          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), xScaleN);
-      xScaleD =
-          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), xScaleD);
-      yOffset =
-          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), yOffset);
-      xOffset =
-          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), xOffset);
-
-      y = rewriter.create<arith::MulFOp>(loc, y, yScaleD);
-      x = rewriter.create<arith::MulFOp>(loc, x, xScaleD);
-
-      y = rewriter.create<arith::AddFOp>(loc, y, yOffset);
-      x = rewriter.create<arith::AddFOp>(loc, x, xOffset);
-
-      y = rewriter.create<arith::DivFOp>(loc, y, yScaleN);
-      x = rewriter.create<arith::DivFOp>(loc, x, xScaleN);
-
-      iy = rewriter.create<math::FloorOp>(loc, y);
-      ix = rewriter.create<math::FloorOp>(loc, x);
-
-      dy = rewriter.create<arith::SubFOp>(loc, y, iy);
-      dx = rewriter.create<arith::SubFOp>(loc, x, ix);
-
-      iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy);
-      ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix);
-    } else {
-      //  dx = x - ix * scale_n;
-      Value y = rewriter.create<arith::MulIOp>(loc, inY, yScaleD);
-      Value x = rewriter.create<arith::MulIOp>(loc, inX, xScaleD);
-
-      y = rewriter.create<arith::AddIOp>(loc, y, yOffset);
-      x = rewriter.create<arith::AddIOp>(loc, x, xOffset);
+    Value resize = genericOp.getResult(0);
 
-      iy = rewriter.create<arith::DivSIOp>(loc, y, yScaleN);
-      ix = rewriter.create<arith::DivSIOp>(loc, x, xScaleN);
-
-      Value tempY = rewriter.create<arith::MulIOp>(loc, iy, yScaleN);
-      Value tempX = rewriter.create<arith::MulIOp>(loc, ix, xScaleN);
-
-      dy = rewriter.create<arith::SubIOp>(loc, y, tempY);
-      dx = rewriter.create<arith::SubIOp>(loc, x, tempX);
-    }
-
-    if (op.getMode() == "NEAREST_NEIGHBOR") {
-      Value yPred, xPred;
-      auto zeroVal = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI32IntegerAttr(0));
-      auto oneVal = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI32IntegerAttr(1));
-
-      // Round the index position towards the closest pixel location.
+    {
+      OpBuilder::InsertionGuard regionGuard(b);
+      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
+                    TypeRange({resultETy}), loc);
+      Value batch = b.create<linalg::IndexOp>(0);
+      Value y = b.create<linalg::IndexOp>(1);
+      Value x = b.create<linalg::IndexOp>(2);
+      Value channel = b.create<linalg::IndexOp>(3);
+
+      Value zeroI32 = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
+      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
+      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
+
+      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
+      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
+
+      bool floatingPointMode = resultETy.isF32();
+
+      SmallVector<int32_t> scale, offset, border;
+      getValuesFromIntArrayAttribute(op.getScale(), scale);
+      getValuesFromIntArrayAttribute(op.getOffset(), offset);
+      getValuesFromIntArrayAttribute(op.getBorder(), border);
+
+      Value yScaleN, yScaleD, xScaleN, xScaleD;
+      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
+      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
+      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
+      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
+
+      Value yOffset, xOffset, yBorder, xBorder;
+      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
+      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
+      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
+      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
+
+      // Compute the ix and dx values for both the X and Y dimensions.
+      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
+                                    Value scaleN, Value scaleD, Value offset,
+                                    int size, ImplicitLocOpBuilder &b) {
+        // x = x * scale_d + offset;
+        // ix = floor(x / scale_n)
+        // dx = x / scale_n - ix
+        Value val = b.create<arith::UIToFPOp>(b.getF32Type(), in);
+        scaleN = b.create<arith::UIToFPOp>(b.getF32Type(), scaleN);
+        scaleD = b.create<arith::UIToFPOp>(b.getF32Type(), scaleD);
+        offset = b.create<arith::UIToFPOp>(b.getF32Type(), offset);
+        val = b.create<arith::MulFOp>(val, scaleD);
+        val = b.create<arith::AddFOp>(val, offset);
+        val = b.create<arith::DivFOp>(val, scaleN);
+        index = b.create<math::FloorOp>(val);
+        delta = b.create<arith::SubFOp>(val, index);
+        index = b.create<arith::FPToSIOp>(b.getI32Type(), index);
+      };
+
+      // Compute the ix and dx values for the X and Y dimensions - int case.
+      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
+                                     Value scaleN, Value scaleD, Value offset,
+                                     int size, ImplicitLocOpBuilder &b) {
+        // x = x * scale_d + offset;
+        // ix = floor(x / scale_n)
+        //  dx = x - ix * scale_n;
+        Value val = b.create<arith::MulIOp>(in, scaleD);
+        val = b.create<arith::AddIOp>(val, offset);
+        index = b.create<arith::DivSIOp>(val, scaleN);
+        delta = b.create<arith::MulIOp>(index, scaleN);
+        delta = b.create<arith::SubIOp>(val, delta);
+      };
+
+      Value ix, iy, dx, dy;
       if (floatingPointMode) {
-        auto halfVal = rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getF32FloatAttr(0.5f));
-        yPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
-                                               dy, halfVal);
-        xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
-                                               dx, halfVal);
+        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
+        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
       } else {
-        Value dyDoubled = rewriter.create<arith::ShLIOp>(loc, dy, oneVal);
-        Value dxDoubled = rewriter.create<arith::ShLIOp>(loc, dx, oneVal);
-        yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
-                                               dyDoubled, yScaleN);
-        xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
-                                               dxDoubled, xScaleN);
+        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
+        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
       }
 
-      auto yOffset =
-          rewriter.create<arith::SelectOp>(loc, yPred, oneVal, zeroVal);
-      auto xOffset =
-          rewriter.create<arith::SelectOp>(loc, xPred, oneVal, zeroVal);
-
-      iy = rewriter.create<arith::AddIOp>(loc, iy, yOffset);
-      ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset);
-
-      // Clamp the to be within the bounds of the input image.
-      iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter);
-      ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter);
-
-      // Read the value from the input array.
-      iy =
-          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), iy);
-      ix =
-          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), ix);
+      if (op.getMode() == "NEAREST_NEIGHBOR") {
+        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
 
-      Value result = rewriter.create<tensor::ExtractOp>(
-          loc, input, ValueRange{batch, iy, ix, channel});
-
-      rewriter.create<linalg::YieldOp>(loc, result);
-    } else {
-      // The mode here must be BILINEAR.
-      assert(op.getMode() == "BILINEAR");
-      Value y0 = iy;
-      Value x0 = ix;
-
-      auto oneVal = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI32IntegerAttr(1));
-      Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
-      Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);
-
-      y0 = clampIntHelper(loc, y0, hwMin, hMax, rewriter);
-      y1 = clampIntHelper(loc, y1, hwMin, hMax, rewriter);
-
-      x0 = clampIntHelper(loc, x0, hwMin, wMax, rewriter);
-      x1 = clampIntHelper(loc, x1, hwMin, wMax, rewriter);
-
-      y0 =
-          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);
-      y1 =
-          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y1);
-      x0 =
-          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x0);
-      x1 =
-          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x1);
-
-      Value y0x0 = rewriter.create<tensor::ExtractOp>(
-          loc, input, ValueRange{batch, y0, x0, channel});
-      Value y0x1 = rewriter.create<tensor::ExtractOp>(
-          loc, input, ValueRange{batch, y0, x1, channel});
-      Value y1x0 = rewriter.create<tensor::ExtractOp>(
-          loc, input, ValueRange{batch, y1, x0, channel});
-      Value y1x1 = rewriter.create<tensor::ExtractOp>(
-          loc, input, ValueRange{batch, y1, x1, channel});
+        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
+                                           Value max, int size,
+                                           ImplicitLocOpBuilder &b) -> Value {
+          if (size == 1) {
+            return b.create<arith::ConstantIndexOp>(0);
+          }
 
-      if (floatingPointMode) {
-        Value rightPart = dx;
-        auto oneVal = rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getF32FloatAttr(1.0f));
-        Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);
+          Value pred;
+          if (floatingPointMode) {
+            auto h = b.create<arith::ConstantOp>(b.getF32FloatAttr(0.5f));
+            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
+          } else {
+            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
+            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
+                                           dvalDouble, scale);
+          }
 
-        y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
-        y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
-        Value topAcc = rewriter.create<arith::AddFOp>(loc, y0x0, y0x1);
+          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
+          val = b.create<arith::AddIOp>(val, offset);
+          val = clampIntHelper(loc, val, zeroI32, max, b);
+          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
+        };
 
-        y1x0 = rewriter.create<arith::MulFOp>(loc, y1x0, leftPart);
-        y1x1 = rewriter.create<arith::MulFOp>(loc, y1x1, rightPart);
-        Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);
+        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
+        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
 
-        Value bottomPart = dy;
-        Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
-        topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
-        bottomAcc = rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
-        Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);
+        Value result = b.create<tensor::ExtractOp>(
+            input, ValueRange{batch, iy, ix, channel});
 
-        rewriter.create<linalg::YieldOp>(loc, result);
+        b.create<linalg::YieldOp>(result);
       } else {
-        // Perform in quantized space.
-        y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
-        y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
-        y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
-        y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
-
-        if (resultElementTy.getIntOrFloatBitWidth() > 32) {
-          dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
-          dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
-        }
-
-        Value xScaleNExt = xScaleN;
-        Value yScaleNExt = yScaleN;
-
-        if (xScaleN.getType() != resultElementTy)
-          xScaleNExt =
-              rewriter.create<arith::ExtSIOp>(loc, resultElementTy, xScaleN);
-
-        if (yScaleN.getType() != resultElementTy)
-          yScaleNExt =
-              rewriter.create<arith::ExtSIOp>(loc, resultElementTy, yScaleN);
-
-        Value topAcc, bottomAcc;
-        if (imageW == 1) {
-          topAcc = rewriter.create<arith::MulIOp>(loc, y0x0, xScaleNExt);
-          bottomAcc = rewriter.create<arith::MulIOp>(loc, y1x0, xScaleNExt);
+        // The mode here must be BILINEAR.
+        assert(op.getMode() == "BILINEAR");
+
+        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
+
+        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
+                                  Value max, ImplicitLocOpBuilder &b) {
+          val0 = in;
+          val1 = b.create<arith::AddIOp>(val0, oneVal);
+          val0 = clampIntHelper(loc, val0, zeroI32, max, b);
+          val1 = clampIntHelper(loc, val1, zeroI32, max, b);
+          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
+          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
+        };
+
+        // Linalg equivalent to the section below:
+        //    int16_t iy0 = apply_max(iy, 0);
+        //    int16_t iy1 = apply_min(iy + 1, IH - 1);
+        //    int16_t ix0 = apply_max(ix, 0);
+        //    int16_t ix1 = apply_min(ix + 1, IW - 1);
+        Value x0, x1, y0, y1;
+        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
+        getClampedIdxs(x0, x1, imageW, ix, wMax, b);
+
+        Value y0x0 = b.create<tensor::ExtractOp>(
+            input, ValueRange{batch, y0, x0, channel});
+        Value y0x1 = b.create<tensor::ExtractOp>(
+            input, ValueRange{batch, y0, x1, channel});
+        Value y1x0 = b.create<tensor::ExtractOp>(
+            input, ValueRange{batch, y1, x0, channel});
+        Value y1x1 = b.create<tensor::ExtractOp>(
+            input, ValueRange{batch, y1, x1, channel});
+
+        if (floatingPointMode) {
+          auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f));
+          auto interpolate = [&](Value val0, Value val1, Value delta,
+                                 ImplicitLocOpBuilder &b) -> Value {
+            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
+            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
+            Value mul1 = b.create<arith::MulFOp>(val1, delta);
+            return b.create<arith::AddFOp>(mul0, mul1);
+          };
+
+          // Linalg equivalent to the section below:
+          //   topAcc = v00 * (unit_x - dx);
+          //   topAcc += v01 * dx;
+          Value topAcc = interpolate(y0x0, y0x1, dx, b);
+
+          // Linalg equivalent to the section below:
+          //   bottomAcc = v10 * (unit_x - dx);
+          //   bottomAcc += v11 * dx;
+          Value bottomAcc = interpolate(y1x0, y1x1, dx, b);
+
+          // Linalg equivalent to the section below:
+          //   result = topAcc * (unit_y - dy) + bottomAcc * dy
+          Value result = interpolate(topAcc, bottomAcc, dy, b);
+          b.create<linalg::YieldOp>(result);
         } else {
-          Value rightPart = dx;
-          Value leftPart = rewriter.create<arith::SubIOp>(loc, xScaleNExt, dx);
+          // Perform in quantized space.
+          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
+          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
+          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
+          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
+
+          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
+          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
+            dx = b.create<arith::ExtSIOp>(resultETy, dx);
+            dy = b.create<arith::ExtSIOp>(resultETy, dy);
+          }
 
-          y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
-          y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
-          topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
+          Value yScaleNExt = yScaleN;
+          Value xScaleNExt = xScaleN;
 
-          y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
-          y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
-          bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
-        }
+          const int64_t scaleBitwidth =
+              xScaleN.getType().getIntOrFloatBitWidth();
+          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
+            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
+            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
+          }
 
-        Value result;
-        if (imageH == 1) {
-          result = rewriter.create<arith::MulIOp>(loc, topAcc, yScaleNExt);
-        } else {
-          Value bottomPart = dy;
-          Value topPart = rewriter.create<arith::SubIOp>(loc, yScaleNExt, dy);
-          topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
-          bottomAcc =
-              rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
-          result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
+          auto interpolate = [](Value val0, Value val1, Value weight0,
+                                Value weight1,
+                                ImplicitLocOpBuilder &b) -> Value {
+            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
+            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
+            return b.create<arith::AddIOp>(mul0, mul1);
+          };
+
+          Value weight0 = b.create<arith::SubIOp>(xScaleNExt, dx);
+          Value weight1 = dx;
+          Value topAcc = interpolate(y0x0, y0x1, weight0, weight1, b);
+          Value bottomAcc = interpolate(y1x0, y1x1, weight0, weight1, b);
+
+          weight0 = b.create<arith::SubIOp>(yScaleNExt, dy);
+          weight1 = dy;
+          Value result = interpolate(topAcc, bottomAcc, weight0, weight1, b);
+          b.create<linalg::YieldOp>(result);
         }
-
-        rewriter.create<linalg::YieldOp>(loc, result);
       }
     }
 

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index b7c817bb60d7a..89cfd7c5a04cc 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -18,7 +18,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9016e41ee1b7e..215a4cc1df50b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -14,7 +14,7 @@
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/BuiltinTypes.h"

diff  --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index a9c77c696f030..c511ca1a6c76e 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -10,7 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -30,15 +30,14 @@ mlir::tosa::condenseValues(const SmallVector<Value> &values) {
   return condensedValues;
 }
 
-Value mlir::tosa::clampFloatHelper(Location loc, Value arg,
-                                   arith::ConstantOp min, arith::ConstantOp max,
-                                   OpBuilder &rewriter) {
+Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
+                                   Value max, OpBuilder &rewriter) {
   Value minValue = rewriter.create<arith::MinFOp>(loc, arg, max);
   return rewriter.create<arith::MaxFOp>(loc, minValue, min);
 }
 
-Value mlir::tosa::clampIntHelper(Location loc, Value arg, arith::ConstantOp min,
-                                 arith::ConstantOp max, OpBuilder &rewriter) {
+Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
+                                 OpBuilder &rewriter) {
   auto smallerThanMin =
       rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
   auto minOrArg =

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
index 3582b6f254217..3aa6d2aac7623 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -124,7 +124,7 @@ func.func @resize_nearest_int(%arg0: tensor<1x15x13x1xi8>) -> () {
   // CHECK: %[[IDX_1:.+]] = linalg.index 1
   // CHECK: %[[IDX_2:.+]] = linalg.index 2
   // CHECK: %[[IDX_3:.+]] = linalg.index 3
-  // CHECK-DAG: %[[XY_MIN:.+]] = arith.constant 0
+  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
   // CHECK-DAG: %[[Y_MAX:.+]] = arith.constant 14
   // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 12
 
@@ -142,66 +142,62 @@ func.func @resize_nearest_int(%arg0: tensor<1x15x13x1xi8>) -> () {
   // find the remainder and integer component of the target index.
 
   // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]]
-  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
   // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]]
-  // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
   // CHECK: %[[I_Y:.*]] = arith.divsi %[[Y]], %[[SCALE_Y_N]]
-  // CHECK: %[[I_X:.*]] = arith.divsi %[[X]], %[[SCALE_X_N]]
   // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]]
-  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
   // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]]
-  // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]]
 
-  // Round to the nearest neighor.
+  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
+  // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
+  // CHECK: %[[I_X:.*]] = arith.divsi %[[X]], %[[SCALE_X_N]]
+  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
+  // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]]
 
-  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0
+  // Compute the offset and bound for the Y position.
   // CHECK-DAG: %[[ONE:.*]] = arith.constant 1
   // CHECK: %[[D_Y_DOUBLE:.*]] = arith.shli %[[D_Y]], %[[ONE]]
-  // CHECK: %[[D_X_DOUBLE:.*]] = arith.shli %[[D_X]], %[[ONE]]
   // CHECK: %[[PRED_Y:.*]] = arith.cmpi sge, %[[D_Y_DOUBLE]], %[[SCALE_Y_N]]
-  // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X_DOUBLE]], %[[SCALE_X_N]]
   // CHECK: %[[VAL_37:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
-  // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
   // CHECK: %[[VAL_39:.*]] = arith.addi %[[I_Y]], %[[VAL_37]]
-  // CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]]
-
-  // This section applies bound checking to be within the input image.
-
-  // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[XY_MIN]]
-  // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[XY_MIN]], %[[VAL_39]]
+  // CHECK: %[[VAL_41:.*]] = arith.cmpi slt, %[[VAL_39]], %[[ZERO]]
+  // CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_41]], %[[ZERO]], %[[VAL_39]]
   // CHECK: %[[VAL_43:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[VAL_39]]
   // CHECK: %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[Y_MAX]], %[[VAL_42]]
-  // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[XY_MIN]]
-  // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[XY_MIN]], %[[VAL_40]]
+  // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]]
+
+  // Compute the offset and bound for the X position.
+  // CHECK: %[[D_X_DOUBLE:.*]] = arith.shli %[[D_X]], %[[ONE]]
+  // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X_DOUBLE]], %[[SCALE_X_N]]
+  // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[VAL_40:.*]] = arith.addi %[[I_X]], %[[VAL_38]]
+  // CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_40]], %[[ZERO]]
+  // CHECK: %[[VAL_46:.*]] = arith.select %[[VAL_45]], %[[ZERO]], %[[VAL_40]]
   // CHECK: %[[VAL_47:.*]] = arith.cmpi slt, %[[X_MAX]], %[[VAL_40]]
   // CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_47]], %[[X_MAX]], %[[VAL_46]]
-
-  // Extract the nearest value using the computed indices.
-
-  // CHECK: %[[IDY:.+]] = arith.index_cast %[[VAL_44]]
   // CHECK: %[[IDX:.+]] = arith.index_cast %[[VAL_48]]
+
   // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX_0]], %[[IDY]], %[[IDX]], %[[IDX_3]]]
   // CHECK: linalg.yield %[[EXTRACT]]
 
   // Round to the nearest index.
   %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = [11, 7, 89, 6], offset = [0, 0], border = [0, 0]} : (tensor<1x15x13x1xi8>) -> tensor<1x23x179x1xi8>
-           return
+  return
 }
 
 // -----
 
 // CHECK-LABEL:  @resize_bilinear_int
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @resize_bilinear_int(%arg0: tensor<1x19x19x1xi8>) {
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x289x289x1xi32>
+func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) {
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x304x320x1xi48>
   // CHECK: %[[GENERIC:.+]] = linalg.generic
   // CHECK: %[[IDX_0:.+]] = linalg.index 0
   // CHECK: %[[IDX_1:.+]] = linalg.index 1
   // CHECK: %[[IDX_2:.+]] = linalg.index 2
   // CHECK: %[[IDX_3:.+]] = linalg.index 3
-  // CHECK-DAG: %[[XY_MIN:.+]] = arith.constant 0
+  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
   // CHECK-DAG: %[[Y_MAX:.+]] = arith.constant 18
-  // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 18
+  // CHECK-DAG: %[[X_MAX:.+]] = arith.constant 19
   // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
   // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
   // CHECK-DAG: %[[SCALE_Y_N:.*]] = arith.constant 16
@@ -214,51 +210,53 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x19x1xi8>) {
   // CHECK-DAG: %[[BORDER_X:.*]] = arith.constant 0
 
   // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[Y]], %[[SCALE_Y_D]]
-  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
   // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]]
-  // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
   // CHECK: %[[I_Y:.*]] = arith.divsi %[[Y]], %[[SCALE_Y_N]]
-  // CHECK: %[[I_X:.*]] = arith.divsi %[[X]], %[[SCALE_X_N]]
   // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]]
-  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
   // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]]
+
+  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
+  // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
+  // CHECK: %[[I_X:.*]] = arith.divsi %[[X]], %[[SCALE_X_N]]
+  // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
   // CHECK: %[[D_X:.*]] = arith.subi %[[X]], %[[TEMP_X]]
 
   // Compute the left, right, and top indices for the bilinear interpolation.
 
   // CHECK-DAG: %[[ONE:.*]] = arith.constant 1
   // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]]
-  // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
 
   // Bound check each dimension.
 
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[ZERO]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_Y]]
   // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]]
   // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
 
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[ZERO]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[Y1]]
   // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]]
   // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
 
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]]
+  // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
+  // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
+
+  // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[ZERO]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_X]]
   // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]]
   // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
 
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[ZERO]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[X1]]
   // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]]
   // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
 
-  // Extract each corner of the bilinear interpolation.
-
-  // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
-  // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
   // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]]
   // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]]
 
+  // Extract each corner of the bilinear interpolation.
+
   // CHECK: %[[LOLO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XLOI]], %[[IDX_3]]]
   // CHECK: %[[LOHI:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YLOI]], %[[XHII]], %[[IDX_3]]]
   // CHECK: %[[HILO:.+]] = tensor.extract %[[ARG0]][%[[IDX_0]], %[[YHII]], %[[XLOI]], %[[IDX_3]]]
@@ -269,24 +267,29 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x19x1xi8>) {
   // CHECK: %[[XHILO:.+]] = arith.extsi %[[HILO]]
   // CHECK: %[[XHIHI:.+]] = arith.extsi %[[HIHI]]
 
+  // CHECK-NEXT: %[[D_X_EXT:.+]] = arith.extsi %[[D_X]]
+  // CHECK-NEXT: %[[D_Y_EXT:.+]] = arith.extsi %[[D_Y]]
+  // CHECK-NEXT: %[[Y_N_EXT:.+]] = arith.extsi %[[SCALE_Y_N]]
+  // CHECK-NEXT: %[[X_N_EXT:.+]] = arith.extsi %[[SCALE_X_N]]
+
   // Compute the bilinear interpolation.
 
-  // CHECK: %[[NDX:.+]] = arith.subi %[[SCALE_X_N]], %[[D_X]]
+  // CHECK: %[[NDX:.+]] = arith.subi %[[X_N_EXT]], %[[D_X_EXT]]
   // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]]
-  // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X]]
+  // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X_EXT]]
   // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]]
   // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]]
-  // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X]]
+  // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X_EXT]]
   // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]]
-  // CHECK: %[[NDY:.+]] = arith.subi %[[SCALE_Y_N]], %[[D_Y]]
+  // CHECK: %[[NDY:.+]] = arith.subi %[[Y_N_EXT]], %[[D_Y_EXT]]
   // CHECK: %[[WLO:.+]] = arith.muli %[[LO]], %[[NDY]]
-  // CHECK: %[[WHI:.+]] = arith.muli %[[HI]], %[[D_Y]]
+  // CHECK: %[[WHI:.+]] = arith.muli %[[HI]], %[[D_Y_EXT]]
   // CHECK: %[[RESULT:.+]] = arith.addi %[[WLO]], %[[WHI]]
   // CHECK: linalg.yield %[[RESULT]]
 
   // Round to the nearest index.
-  %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x19x1xi8>) -> tensor<1x289x289x1xi32>
-           return
+  %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x20x1xi8>) -> tensor<1x304x320x1xi48>
+  return
 }
 
 // -----
@@ -299,7 +302,7 @@ func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () {
   // CHECK: %[[IDX1:.+]] = linalg.index 1
   // CHECK: %[[IDX2:.+]] = linalg.index 2
   // CHECK: %[[IDX3:.+]] = linalg.index 3
-  // CHECK-DAG: %[[XYMIN:.*]] = arith.constant 0
+  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0
   // CHECK-DAG: %[[YMAX:.*]] = arith.constant 49
   // CHECK-DAG: %[[XMAX:.*]] = arith.constant 47
   // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX1]]
@@ -314,72 +317,68 @@ func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () {
   // CHECK-DAG: %[[IBORDER_X:.*]] = arith.constant 31
 
   // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]]
-  // CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
   // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]]
   // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]]
+  // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]]
+  // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
+  // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
+  // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
+  // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
+  // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
+  // CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]]
+
+  // CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
   // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]]
   // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]]
-  // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]]
   // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]]
-
-  // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
   // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]]
-  // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
   // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]]
-  // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
   // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]]
-
-  // Find the remainder and integer component of the target index.
-
-  // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
   // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]]
-  // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
   // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]]
-  // CHECK: %[[VAL_39:.*]] = arith.fptosi %[[VAL_35]]
   // CHECK: %[[VAL_40:.*]] = arith.fptosi %[[VAL_36]]
 
-  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0
   // CHECK-DAG: %[[ONE:.*]] = arith.constant 1
   // CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01
   // CHECK: %[[PRED_Y:.*]] = arith.cmpf oge, %[[D_Y]], %[[HALF]]
-  // CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]]
   // CHECK: %[[ROUND_Y:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
-  // CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
   // CHECK: %[[VAL_48:.*]] = arith.addi %[[VAL_39]], %[[ROUND_Y]]
-  // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]]
-
-  // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[XYMIN]]
-  // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[XYMIN]], %[[VAL_48]]
+  // CHECK: %[[VAL_50:.*]] = arith.cmpi slt, %[[VAL_48]], %[[ZERO]]
+  // CHECK: %[[VAL_51:.*]] = arith.select %[[VAL_50]], %[[ZERO]], %[[VAL_48]]
   // CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[YMAX]], %[[VAL_48]]
   // CHECK: %[[VAL_53:.*]] = arith.select %[[VAL_52]], %[[YMAX]], %[[VAL_51]]
-  // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[XYMIN]]
-  // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[XYMIN]], %[[VAL_49]]
+  // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]]
+
+  // CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01
+  // CHECK: %[[PRED_X:.*]] = arith.cmpf oge, %[[D_X]], %[[HALF]]
+  // CHECK: %[[ROUND_X:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_40]], %[[ROUND_X]]
+  // CHECK: %[[VAL_54:.*]] = arith.cmpi slt, %[[VAL_49]], %[[ZERO]]
+  // CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_54]], %[[ZERO]], %[[VAL_49]]
   // CHECK: %[[VAL_56:.*]] = arith.cmpi slt, %[[XMAX]], %[[VAL_49]]
   // CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[XMAX]], %[[VAL_55]]
-
-  // CHECK: %[[IDY:.*]] = arith.index_cast %[[VAL_53]]
   // CHECK: %[[IDX:.*]] = arith.index_cast %[[VAL_57]]
+
   // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[IDY]], %[[IDX]], %[[IDX3]]]
   // CHECK: linalg.yield %[[EXTRACT]]
 
   %output = "tosa.resize"(%input) {mode = "NEAREST_NEIGHBOR", scale = [64, 2, 64, 2], offset = [-31, -31], border = [31, 31]} : (tensor<1x50x48x1xf32>) -> tensor<1x1600x1536x1xf32>
-
   return
 }
 
 // -----
 
 // CHECK-LABEL: @resize_bilinear_fp
-func.func @resize_bilinear_fp(%input: tensor<1x23x23x1xf32>) -> () {
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x89x89x1xf32>
+func.func @resize_bilinear_fp(%input: tensor<1x23x24x1xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x92x96x1xf32>
   // CHECK: %[[GENERIC:.+]] = linalg.generic
   // CHECK: %[[IDX_0:.+]] = linalg.index 0
   // CHECK: %[[IDX_1:.+]] = linalg.index 1
   // CHECK: %[[IDX_2:.+]] = linalg.index 2
   // CHECK: %[[IDX_3:.+]] = linalg.index 3
-  // CHECK-DAG: %[[XY_MIN:.*]] = arith.constant 0
+  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0
   // CHECK-DAG: %[[Y_MAX:.*]] = arith.constant 22
-  // CHECK-DAG: %[[X_MAX:.*]] = arith.constant 22
+  // CHECK-DAG: %[[X_MAX:.*]] = arith.constant 23
   // CHECK: %[[Y:.+]] = arith.index_cast %[[IDX_1]]
   // CHECK: %[[X:.+]] = arith.index_cast %[[IDX_2]]
   // CHECK-DAG: %[[ISCALE_Y_N:.*]] = arith.constant 4
@@ -392,58 +391,58 @@ func.func @resize_bilinear_fp(%input: tensor<1x23x23x1xf32>) -> () {
   // CHECK-DAG: %[[IBORDER_X:.*]] = arith.constant 0
 
   // CHECK: %[[Y0:.+]] = arith.uitofp %[[Y]]
-  // CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
   // CHECK: %[[SCALE_Y_N:.*]] = arith.uitofp %[[ISCALE_Y_N]]
   // CHECK: %[[SCALE_Y_D:.*]] = arith.uitofp %[[ISCALE_Y_D]]
+  // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]]
+  // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
+  // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
+  // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
+  // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
+  // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
+  // CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]]
+
+  // CHECK: %[[X0:.+]] = arith.uitofp %[[X]]
   // CHECK: %[[SCALE_X_N:.*]] = arith.uitofp %[[ISCALE_X_N]]
   // CHECK: %[[SCALE_X_D:.*]] = arith.uitofp %[[ISCALE_X_D]]
-  // CHECK: %[[OFFSET_Y:.*]] = arith.uitofp %[[IOFFSET_Y]]
   // CHECK: %[[OFFSET_X:.*]] = arith.uitofp %[[IOFFSET_X]]
-
-  // CHECK: %[[VAL_29:.*]] = arith.mulf %[[Y0]], %[[SCALE_Y_D]]
   // CHECK: %[[VAL_30:.*]] = arith.mulf %[[X0]], %[[SCALE_X_D]]
-  // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_29]], %[[OFFSET_Y]]
   // CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[OFFSET_X]]
-  // CHECK: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[SCALE_Y_N]]
   // CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_32]], %[[SCALE_X_N]]
-
-  // CHECK: %[[VAL_35:.*]] = math.floor %[[VAL_33]]
   // CHECK: %[[VAL_36:.*]] = math.floor %[[VAL_34]]
-  // CHECK: %[[D_Y:.*]] = arith.subf %[[VAL_33]], %[[VAL_35]]
   // CHECK: %[[D_X:.*]] = arith.subf %[[VAL_34]], %[[VAL_36]]
-  // CHECK: %[[I_Y:.*]] = arith.fptosi %[[VAL_35]]
   // CHECK: %[[I_X:.*]] = arith.fptosi %[[VAL_36]]
 
   // Compute the left, right, and top indices for the bilinear interpolation.
 
-  // CHECK-DAG: %[[ONE:.*]] = arith.constant 1
-  // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]]
-  // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
+  // CHECK: %[[ONE:.*]] = arith.constant 1
 
   // Bound check each dimension.
 
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_Y]]
+  // CHECK: %[[Y1:.*]] = arith.addi %[[I_Y]], %[[ONE]]
+
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_Y]], %[[ZERO]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_Y]]
   // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[I_Y]]
   // CHECK: %[[YLO:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
 
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[Y1]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y1]], %[[ZERO]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[Y1]]
   // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[Y_MAX]], %[[Y1]]
   // CHECK: %[[YHI:.*]] = arith.select %[[PRED]], %[[Y_MAX]], %[[BOUND]]
+  // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
+  // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
 
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[I_X]]
+  // CHECK: %[[X1:.*]] = arith.addi %[[I_X]], %[[ONE]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[I_X]], %[[ZERO]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[I_X]]
   // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[I_X]]
   // CHECK: %[[XLO:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
 
-  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[XY_MIN]]
-  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[XY_MIN]], %[[X1]]
+  // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X1]], %[[ZERO]]
+  // CHECK: %[[BOUND:.*]] = arith.select %[[PRED]], %[[ZERO]], %[[X1]]
   // CHECK: %[[PRED:.*]] = arith.cmpi slt, %[[X_MAX]], %[[X1]]
   // CHECK: %[[XHI:.*]] = arith.select %[[PRED]], %[[X_MAX]], %[[BOUND]]
 
-  // CHECK: %[[YLOI:.+]] = arith.index_cast %[[YLO]]
-  // CHECK: %[[YHII:.+]] = arith.index_cast %[[YHI]]
   // CHECK: %[[XLOI:.+]] = arith.index_cast %[[XLO]]
   // CHECK: %[[XHII:.+]] = arith.index_cast %[[XHI]]
 
@@ -457,6 +456,7 @@ func.func @resize_bilinear_fp(%input: tensor<1x23x23x1xf32>) -> () {
   // CHECK: %[[WLOLO:.+]] = arith.mulf %[[LOLO]], %[[NDX]]
   // CHECK: %[[WLOHI:.+]] = arith.mulf %[[LOHI]], %[[D_X]]
   // CHECK: %[[LO:.+]] = arith.addf %[[WLOLO]], %[[WLOHI]]
+  // CHECK: %[[NDX:.+]] = arith.subf %[[ONE]], %[[D_X]]
   // CHECK: %[[WHILO:.+]] = arith.mulf %[[HILO]], %[[NDX]]
   // CHECK: %[[WHIHI:.+]] = arith.mulf %[[HIHI]], %[[D_X]]
   // CHECK: %[[HI:.+]] = arith.addf %[[WHILO]], %[[WHIHI]]
@@ -467,7 +467,7 @@ func.func @resize_bilinear_fp(%input: tensor<1x23x23x1xf32>) -> () {
   // CHECK: linalg.yield %[[RESULT]]
 
   // Round by bilinear interpolation
-  %output = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x23x1xf32>) -> tensor<1x89x89x1xf32>
+  %output = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x23x24x1xf32>) -> tensor<1x92x96x1xf32>
 
   return
 }


        


More information about the Mlir-commits mailing list