[Mlir-commits] [mlir] 97c571a - [mlir][tosa] Add tosa.resize lowering to linalg generic

Rob Suderman llvmlistbot at llvm.org
Fri Apr 23 12:50:38 PDT 2021


Author: Rob Suderman
Date: 2021-04-23T12:43:02-07:00
New Revision: 97c571abbcea2b540511ae5a874da05bf77e5e5d

URL: https://github.com/llvm/llvm-project/commit/97c571abbcea2b540511ae5a874da05bf77e5e5d
DIFF: https://github.com/llvm/llvm-project/commit/97c571abbcea2b540511ae5a874da05bf77e5e5d.diff

LOG: [mlir][tosa] Add tosa.resize lowering to linalg generic

Includes tests and implementation for both integer and floating point values.
Both nearest neighbor and bilinear interpolation are included.

Differential Revision: https://reviews.llvm.org/D101009

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8ef186e257ec2..21ef0b84a3dd5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1126,6 +1126,277 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
   }
 };
 
+class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
+public:
+  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
+  /// Lowers tosa.resize to a linalg.indexed_generic that gathers from input.
+  LogicalResult matchAndRewrite(tosa::ResizeOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    auto input = op.input();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto resultTy = op.getType().cast<ShapedType>();
+    auto resultElementTy = resultTy.getElementType();
+    // The clamp bounds below are baked in from the input's static H and W.
+    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+    if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
+      return failure();
+
+    auto imageH = inputTy.getShape()[1];
+    auto imageW = inputTy.getShape()[2];
+
+    auto initTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
+                                          resultTy.getShape(), resultElementTy)
+            .result();
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+    auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
+        loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(resultTy.getRank()));
+    rewriter.replaceOp(op, genericOp.getResult(0));
+
+    {
+      OpBuilder::InsertionGuard regionGuard(rewriter);
+      Block *block = rewriter.createBlock(
+          &genericOp.region(), genericOp.region().end(),
+          TypeRange({rewriter.getIndexType(), rewriter.getIndexType(),
+                     rewriter.getIndexType(), rewriter.getIndexType(),
+                     resultElementTy}));
+      Value batch = block->getArgument(0);
+      Value channel = block->getArgument(3);
+      // Clamp bounds (i32) for indices into input dimensions 1 and 2.
+      auto hwMin =
+          rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
+      auto hMax = rewriter.create<ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(imageH - 1));
+      auto wMax = rewriter.create<ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(imageW - 1));
+
+      Value inY = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(),
+                                               block->getArgument(1));
+      Value inX = rewriter.create<IndexCastOp>(loc, rewriter.getI32Type(),
+                                               block->getArgument(2));
+      // A zero shift selects the floating-point stride/offset attributes.
+      int32_t shift = op.shift();
+      bool floatingPointMode = shift == 0;
+
+      Value yStride, xStride, yOffset, xOffset;
+      if (floatingPointMode) {
+        yStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[0]);
+        xStride = rewriter.create<ConstantOp>(loc, op.stride_fp()[1]);
+        yOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[0]);
+        xOffset = rewriter.create<ConstantOp>(loc, op.offset_fp()[1]);
+      } else {
+        SmallVector<int32_t> stride, offset;
+        getValuesFromIntArrayAttribute(op.stride(), stride);
+        getValuesFromIntArrayAttribute(op.offset(), offset);
+
+        yStride = rewriter.create<ConstantOp>(
+            loc, rewriter.getI32IntegerAttr(stride[0]));
+        xStride = rewriter.create<ConstantOp>(
+            loc, rewriter.getI32IntegerAttr(stride[1]));
+        yOffset = rewriter.create<ConstantOp>(
+            loc, rewriter.getI32IntegerAttr(offset[0]));
+        xOffset = rewriter.create<ConstantOp>(
+            loc, rewriter.getI32IntegerAttr(offset[1]));
+      }
+
+      // Compute the integer index and partial offset.
+      // x = x * stride + offset;
+      // ix = floor(x)
+      // dx = x - ix
+      Value ix, iy, dx, dy;
+      if (floatingPointMode) {
+        Value y = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inY);
+        Value x = rewriter.create<UIToFPOp>(loc, rewriter.getF32Type(), inX);
+
+        y = rewriter.create<MulFOp>(loc, y, yStride);
+        x = rewriter.create<MulFOp>(loc, x, xStride);
+
+        y = rewriter.create<AddFOp>(loc, y, yOffset);
+        x = rewriter.create<AddFOp>(loc, x, xOffset);
+
+        iy = rewriter.create<FloorFOp>(loc, y);
+        ix = rewriter.create<FloorFOp>(loc, x);
+
+        dy = rewriter.create<SubFOp>(loc, y, iy);
+        dx = rewriter.create<SubFOp>(loc, x, ix);
+
+        iy = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), iy);
+        ix = rewriter.create<FPToSIOp>(loc, rewriter.getI32Type(), ix);
+      } else {
+        Value shiftVal =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(shift));
+
+        Value y = rewriter.create<MulIOp>(loc, inY, yStride);
+        Value x = rewriter.create<MulIOp>(loc, inX, xStride);
+
+        y = rewriter.create<AddIOp>(loc, y, yOffset);
+        x = rewriter.create<AddIOp>(loc, x, xOffset);
+
+        iy = rewriter.create<SignedShiftRightOp>(loc, y, shiftVal);
+        ix = rewriter.create<SignedShiftRightOp>(loc, x, shiftVal);
+
+        Value yTrunc = rewriter.create<ShiftLeftOp>(loc, iy, shiftVal);
+        Value xTrunc = rewriter.create<ShiftLeftOp>(loc, ix, shiftVal);
+
+        dy = rewriter.create<SubIOp>(loc, y, yTrunc);
+        dx = rewriter.create<SubIOp>(loc, x, xTrunc);
+      }
+
+      if (op.mode() == "NEAREST_NEIGHBOR") {
+        Value yPred, xPred;
+        // Round the index position towards the closest pixel location.
+        if (floatingPointMode) {
+          auto halfVal =
+              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
+          yPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dy,
+                                                halfVal);
+          xPred = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, dx,
+                                                halfVal);
+        } else {
+          auto halfVal = rewriter.create<ConstantOp>(
+              loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
+          yPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dy,
+                                                halfVal);
+          xPred = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, dx,
+                                                halfVal);
+        }
+
+        auto zeroVal =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
+        auto oneVal =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
+        // Step to the next pixel when the fraction is at least one half.
+        auto yRound =
+            rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
+        auto xRound =
+            rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);
+
+        iy = rewriter.create<AddIOp>(loc, iy, yRound);
+        ix = rewriter.create<AddIOp>(loc, ix, xRound);
+
+        // Clamp the index to be within the bounds of the input image.
+
+        iy = clampHelper<mlir::CmpIOp>(loc, iy, hwMin, hMax, CmpIPredicate::slt,
+                                       rewriter);
+        ix = clampHelper<mlir::CmpIOp>(loc, ix, hwMin, wMax, CmpIPredicate::slt,
+                                       rewriter);
+
+        // Read the value from the input array.
+        iy = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), iy);
+        ix = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), ix);
+
+        Value result = rewriter.create<tensor::ExtractOp>(
+            loc, input, ValueRange{batch, iy, ix, channel});
+
+        rewriter.create<linalg::YieldOp>(loc, result);
+
+        return success();
+      }
+
+      if (op.mode() == "BILINEAR") {
+        Value y0 = iy;
+        Value x0 = ix;
+        // (y1, x1) is the neighbour one step past (y0, x0) on each axis.
+        auto oneVal =
+            rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
+        Value y1 = rewriter.create<AddIOp>(loc, y0, oneVal);
+        Value x1 = rewriter.create<AddIOp>(loc, x0, oneVal);
+
+        y0 = clampHelper<mlir::CmpIOp>(loc, y0, hwMin, hMax, CmpIPredicate::slt,
+                                       rewriter);
+        y1 = clampHelper<mlir::CmpIOp>(loc, y1, hwMin, hMax, CmpIPredicate::slt,
+                                       rewriter);
+
+        x0 = clampHelper<mlir::CmpIOp>(loc, x0, hwMin, wMax, CmpIPredicate::slt,
+                                       rewriter);
+        x1 = clampHelper<mlir::CmpIOp>(loc, x1, hwMin, wMax, CmpIPredicate::slt,
+                                       rewriter);
+
+        y0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y0);
+        y1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), y1);
+        x0 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x0);
+        x1 = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), x1);
+
+        Value y0x0 = rewriter.create<tensor::ExtractOp>(
+            loc, input, ValueRange{batch, y0, x0, channel});
+        Value y0x1 = rewriter.create<tensor::ExtractOp>(
+            loc, input, ValueRange{batch, y0, x1, channel});
+        Value y1x0 = rewriter.create<tensor::ExtractOp>(
+            loc, input, ValueRange{batch, y1, x0, channel});
+        Value y1x1 = rewriter.create<tensor::ExtractOp>(
+            loc, input, ValueRange{batch, y1, x1, channel});
+        // Blend the corners along x using dx, then the two rows using dy.
+        if (floatingPointMode) {
+          auto oneVal =
+              rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.f));
+          Value rightPart = dx;
+          Value leftPart = rewriter.create<SubFOp>(loc, oneVal, dx);
+
+          y0x0 = rewriter.create<MulFOp>(loc, y0x0, leftPart);
+          y0x1 = rewriter.create<MulFOp>(loc, y0x1, rightPart);
+          Value topAcc = rewriter.create<AddFOp>(loc, y0x0, y0x1);
+
+          y1x0 = rewriter.create<MulFOp>(loc, y1x0, leftPart);
+          y1x1 = rewriter.create<MulFOp>(loc, y1x1, rightPart);
+          Value bottomAcc = rewriter.create<AddFOp>(loc, y1x0, y1x1);
+
+          Value bottomPart = dy;
+          Value topPart = rewriter.create<SubFOp>(loc, oneVal, dy);
+          topAcc = rewriter.create<MulFOp>(loc, topAcc, topPart);
+          bottomAcc = rewriter.create<MulFOp>(loc, bottomAcc, bottomPart);
+          Value result = rewriter.create<AddFOp>(loc, topAcc, bottomAcc);
+
+          rewriter.create<linalg::YieldOp>(loc, result);
+          return success();
+        } else {
+          y0x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x0);
+          y0x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y0x1);
+          y1x0 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x0);
+          y1x1 = rewriter.create<SignExtendIOp>(loc, resultElementTy, y1x1);
+          // dx and dy are i32; widen them if the result type is wider.
+          if (resultElementTy.getIntOrFloatBitWidth() > 32) {
+            dx = rewriter.create<SignExtendIOp>(loc, resultElementTy, dx);
+            dy = rewriter.create<SignExtendIOp>(loc, resultElementTy, dy);
+          }
+          // 1 << shift plays the role 1.0 plays in the floating-point branch.
+          auto unitVal = rewriter.create<ConstantOp>(
+              loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
+          Value rightPart = dx;
+          Value leftPart = rewriter.create<SubIOp>(loc, unitVal, dx);
+
+          y0x0 = rewriter.create<MulIOp>(loc, y0x0, leftPart);
+          y0x1 = rewriter.create<MulIOp>(loc, y0x1, rightPart);
+          Value topAcc = rewriter.create<AddIOp>(loc, y0x0, y0x1);
+
+          y1x0 = rewriter.create<MulIOp>(loc, y1x0, leftPart);
+          y1x1 = rewriter.create<MulIOp>(loc, y1x1, rightPart);
+          Value bottomAcc = rewriter.create<AddIOp>(loc, y1x0, y1x1);
+
+          Value bottomPart = dy;
+          Value topPart = rewriter.create<SubIOp>(loc, unitVal, dy);
+          topAcc = rewriter.create<MulIOp>(loc, topAcc, topPart);
+          bottomAcc = rewriter.create<MulIOp>(loc, bottomAcc, bottomPart);
+          Value result = rewriter.create<AddIOp>(loc, topAcc, bottomAcc);
+
+          rewriter.create<linalg::YieldOp>(loc, result);
+          return success();
+        }
+      }
+      // Not reached: unsupported modes were rejected before creating any IR.
+      return failure();
+    }
+
+    return success();
+  }
+};
+
 // At the codegen level any identity operations should be removed. Any cases
 // where identity is load-bearing (e.g. cross device computation) should be
 // handled before lowering to codegen.
@@ -1817,6 +2088,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       PadConverter,
       ReshapeConverter,
       RescaleConverter,
+      ResizeConverter,
       ReverseConverter,
       TableConverter,
       TileConverter,

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 09392e6cef5db..ff4dbf4ac0529 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -963,3 +963,292 @@ func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x
   %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [2, 1]} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @resize_nearest
+func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic
+  // CHECK-DAG: %[[XYMIN:.+]] = constant 0
+  // CHECK-DAG: %[[YMAX:.+]] = constant 1
+  // CHECK-DAG: %[[XMAX:.+]] = constant 1
+  // CHECK-DAG: %[[Y:.+]] = index_cast %arg2
+  // CHECK-DAG: %[[X:.+]] = index_cast %arg3
+  // CHECK-DAG: %[[STRIDEY:.+]] = constant 5.000000e-01
+  // CHECK-DAG: %[[STRIDEX:.+]] = constant 5.000000e-01
+  // CHECK-DAG: %[[OFFSETY:.+]] = constant 1.000000e-01
+  // CHECK-DAG: %[[OFFSETX:.+]] = constant 2.000000e-01
+  // CHECK-DAG: %[[VAL4:.+]] = uitofp %[[Y]]
+  // CHECK-DAG: %[[VAL5:.+]] = uitofp %[[X]]
+  // CHECK-DAG: %[[VAL6:.+]] = mulf %[[VAL4]], %[[STRIDEY]]
+  // CHECK-DAG: %[[VAL7:.+]] = mulf %[[VAL5]], %[[STRIDEX]]
+  // CHECK-DAG: %[[VAL8:.+]] = addf %[[VAL6]], %[[OFFSETY]]
+  // CHECK-DAG: %[[VAL9:.+]] = addf %[[VAL7]], %[[OFFSETX]]
+  // The lines above compute index * stride + offset in floating point.
+  // Find the remainder and integer component of the target index.
+
+  // CHECK-DAG: %[[VAL10:.+]] = floorf %[[VAL8]]
+  // CHECK-DAG: %[[VAL11:.+]] = floorf %[[VAL9]]
+  // CHECK-DAG: %[[VAL12:.+]] = subf %[[VAL8]], %[[VAL10]]
+  // CHECK-DAG: %[[VAL13:.+]] = subf %[[VAL9]], %[[VAL11]]
+  // CHECK-DAG: %[[VAL14:.+]] = fptosi %[[VAL10]]
+  // CHECK-DAG: %[[VAL15:.+]] = fptosi %[[VAL11]]
+
+  // Round to the nearest index.
+
+  // CHECK-DAG: %[[ROUND:.+]] = constant 5.000000e-01
+  // CHECK-DAG: %[[VAL16:.+]] = cmpf oge, %[[VAL12]], %[[ROUND]]
+  // CHECK-DAG: %[[VAL17:.+]] = cmpf oge, %[[VAL13]], %[[ROUND]]
+  // CHECK-DAG: %[[ZERO:.+]] = constant 0
+  // CHECK-DAG: %[[ONE:.+]] = constant 1
+  // CHECK-DAG: %[[VAL18:.+]] = select %[[VAL16]], %[[ONE]], %[[ZERO]]
+  // CHECK-DAG: %[[VAL19:.+]] = select %[[VAL17]], %[[ONE]], %[[ZERO]]
+  // CHECK-DAG: %[[VAL20:.+]] = addi %[[VAL14]], %[[VAL18]]
+  // CHECK-DAG: %[[VAL21:.+]] = addi %[[VAL15]], %[[VAL19]]
+
+  // Clamp the rounded indices to lie within the bounds of the input image.
+
+  // CHECK-DAG: %[[VAL22:.+]] = cmpi slt, %[[VAL20]], %[[XYMIN]]
+  // CHECK-DAG: %[[VAL23:.+]] = select %[[VAL22]], %[[XYMIN]], %[[VAL20]]
+  // CHECK-DAG: %[[VAL24:.+]] = cmpi slt, %[[YMAX]], %[[VAL20]]
+  // CHECK-DAG: %[[VAL25:.+]] = select %[[VAL24]], %[[YMAX]], %[[VAL23]]
+  // CHECK-DAG: %[[VAL26:.+]] = cmpi slt, %[[VAL21]], %[[XYMIN]]
+  // CHECK-DAG: %[[VAL27:.+]] = select %[[VAL26]], %[[XYMIN]], %[[VAL21]]
+  // CHECK-DAG: %[[VAL28:.+]] = cmpi slt, %[[XMAX]], %[[VAL21]]
+  // CHECK-DAG: %[[VAL29:.+]] = select %[[VAL28]], %[[XMAX]], %[[VAL27]]
+
+  // Extract the nearest value using the computed indices.
+
+  // CHECK-DAG: %[[IDY:.+]] = index_cast %[[VAL25]]
+  // CHECK-DAG: %[[IDX:.+]] = index_cast %[[VAL29]]
+  // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%arg1, %[[IDY]], %[[IDX]], %arg4]
+  // CHECK: linalg.yield %[[EXTRACT]]
+  %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [0, 0], offset = [0, 0], stride_fp = [0.5 : f32, 0.5 : f32], offset_fp = [0.1 : f32, 0.2 : f32], shift = 0 : i32, mode = "NEAREST_NEIGHBOR" } : (tensor<1x2x2x1xf32>)  -> (tensor<1x4x4x1xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @resize_bilinear
+func @resize_bilinear(%input: tensor<1x2x2x1xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic
+  // CHECK: %[[XYMIN:.+]] = constant 0
+  // CHECK: %[[YMAX:.+]] = constant 1
+  // CHECK: %[[XMAX:.+]] = constant 1
+  // Split the sampling position into integer part and fractional remainder.
+  // CHECK: %[[VAL10:.+]] = floorf %[[VAL8:.+]]
+  // CHECK: %[[VAL11:.+]] = floorf %[[VAL9:.+]]
+
+  // CHECK: %[[DY:.+]] = subf %[[VAL8:.+]], %[[VAL10]] 
+  // CHECK: %[[DX:.+]] = subf %[[VAL9:.+]], %[[VAL11]] 
+
+  // CHECK: %[[Y0:.+]] = fptosi %[[VAL10]]
+  // CHECK: %[[X0:.+]] = fptosi %[[VAL11]]
+
+  // Compute the second sample index on each axis: (Y1, X1) = (Y0, X0) + 1.
+
+  // CHECK: %[[ONE:.+]] = constant 1
+  // CHECK: %[[Y1:.+]] = addi %[[Y0]], %[[ONE]]
+  // CHECK: %[[X1:.+]] = addi %[[X0]], %[[ONE]]
+
+  // Bound check each dimension.
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[Y0]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y0]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[YMAX]], %[[Y0]]
+  // CHECK: %[[YLO:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[Y1]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y1]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[YMAX]], %[[Y1]]
+  // CHECK: %[[YHI:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[X0]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X0]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[XMAX]], %[[X0]]
+  // CHECK: %[[XLO:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[X1]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X1]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[XMAX]], %[[X1]]
+  // CHECK: %[[XHI:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]]
+
+  // Extract each corner of the bilinear interpolation.
+
+  // CHECK: %[[YLOI:.+]] = index_cast %[[YLO]]
+  // CHECK: %[[YHII:.+]] = index_cast %[[YHI]]
+  // CHECK: %[[XLOI:.+]] = index_cast %[[XLO]]
+  // CHECK: %[[XHII:.+]] = index_cast %[[XHI]]
+
+  // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%arg1, %[[YLOI]], %[[XLOI]], %arg4]
+  // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%arg1, %[[YLOI]], %[[XHII]], %arg4]
+  // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%arg1, %[[YHII]], %[[XLOI]], %arg4]
+  // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%arg1, %[[YHII]], %[[XHII]], %arg4]
+    
+  // Interpolate along x within each row, then along y between the two rows.
+
+  // CHECK: %[[ONE:.+]] = constant 1.000000e+00
+  // CHECK: %[[NDX:.+]] = subf %[[ONE]], %[[DX]]
+  // CHECK: %[[WLOLO:.+]] = mulf %[[LOLO]], %[[NDX]]
+  // CHECK: %[[WLOHI:.+]] = mulf %[[LOHI]], %[[DX]]
+  // CHECK: %[[LO:.+]] = addf %[[WLOLO]], %[[WLOHI]]
+  // CHECK: %[[WHILO:.+]] = mulf %[[HILO]], %[[NDX]]
+  // CHECK: %[[WHIHI:.+]] = mulf %[[HIHI]], %[[DX]]
+  // CHECK: %[[HI:.+]] = addf %[[WHILO]], %[[WHIHI]]
+  // CHECK: %[[NDY:.+]] = subf %[[ONE]], %[[DY]]
+  // CHECK: %[[WLO:.+]] = mulf %[[LO]], %[[NDY]]
+  // CHECK: %[[WHI:.+]] = mulf %[[HI]], %[[DY]]
+  // CHECK: %[[RESULT:.+]] = addf %[[WLO]], %[[WHI]]
+  // CHECK: linalg.yield %[[RESULT]]
+  %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [0, 0], offset = [0, 0], stride_fp = [0.5 : f32, 0.5 : f32], offset_fp = [0.1 : f32, 0.2 : f32], shift = 0 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xf32>)  -> (tensor<1x4x4x1xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @resize_nearest_int
+func @resize_nearest_int(%input: tensor<1x2x2x1xi32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic
+  // CHECK-DAG: %[[XYMIN:.+]] = constant 0
+  // CHECK-DAG: %[[YMAX:.+]] = constant 1
+  // CHECK-DAG: %[[XMAX:.+]] = constant 1
+  // CHECK-DAG: %[[Y:.+]] = index_cast %arg2
+  // CHECK-DAG: %[[X:.+]] = index_cast %arg3
+  // CHECK-DAG: %[[STRIDEY:.+]] = constant 128
+  // CHECK-DAG: %[[STRIDEX:.+]] = constant 128
+  // CHECK-DAG: %[[OFFSETY:.+]] = constant 1
+  // CHECK-DAG: %[[OFFSETX:.+]] = constant 2
+  // CHECK-DAG: %[[EIGHT:.+]] = constant 8
+  // CHECK-DAG: %[[VAL4:.+]] = muli %[[Y]], %[[STRIDEY]]
+  // CHECK-DAG: %[[VAL5:.+]] = muli %[[X]], %[[STRIDEX]]
+  // CHECK-DAG: %[[VAL6:.+]] = addi %[[VAL4]], %[[OFFSETY]]
+  // CHECK-DAG: %[[VAL7:.+]] = addi %[[VAL5]], %[[OFFSETX]]
+  // The lines above compute index * stride + offset in integer arithmetic.
+  // Find the remainder and integer component of the target index.
+
+  // Shifting right then left by 8 isolates the integer part for the subi.
+  // CHECK-DAG: %[[VAL8:.+]] = shift_right_signed %[[VAL6]], %[[EIGHT]]
+  // CHECK-DAG: %[[VAL9:.+]] = shift_right_signed %[[VAL7]], %[[EIGHT]]
+  // CHECK-DAG: %[[VAL10:.+]] = shift_left %[[VAL8]], %[[EIGHT]]
+  // CHECK-DAG: %[[VAL11:.+]] = shift_left %[[VAL9]], %[[EIGHT]]
+  // CHECK-DAG: %[[VAL12:.+]] = subi %[[VAL6]], %[[VAL10]]
+  // CHECK-DAG: %[[VAL13:.+]] = subi %[[VAL7]], %[[VAL11]]
+
+  // Round to the nearest index.
+
+  // CHECK-DAG: %[[ROUND:.+]] = constant 128
+  // CHECK-DAG: %[[VAL16:.+]] = cmpi sge, %[[VAL12]], %[[ROUND]]
+  // CHECK-DAG: %[[VAL17:.+]] = cmpi sge, %[[VAL13]], %[[ROUND]]
+  // CHECK-DAG: %[[ZERO:.+]] = constant 0
+  // CHECK-DAG: %[[ONE:.+]] = constant 1
+  // CHECK-DAG: %[[VAL18:.+]] = select %[[VAL16]], %[[ONE]], %[[ZERO]]
+  // CHECK-DAG: %[[VAL19:.+]] = select %[[VAL17]], %[[ONE]], %[[ZERO]]
+  // CHECK-DAG: %[[VAL20:.+]] = addi %[[VAL8]], %[[VAL18]]
+  // CHECK-DAG: %[[VAL21:.+]] = addi %[[VAL9]], %[[VAL19]]
+
+  // Clamp the rounded indices to lie within the bounds of the input image.
+
+  // CHECK-DAG: %[[VAL22:.+]] = cmpi slt, %[[VAL20]], %[[XYMIN]]
+  // CHECK-DAG: %[[VAL23:.+]] = select %[[VAL22]], %[[XYMIN]], %[[VAL20]]
+  // CHECK-DAG: %[[VAL24:.+]] = cmpi slt, %[[YMAX]], %[[VAL20]]
+  // CHECK-DAG: %[[VAL25:.+]] = select %[[VAL24]], %[[YMAX]], %[[VAL23]]
+  // CHECK-DAG: %[[VAL26:.+]] = cmpi slt, %[[VAL21]], %[[XYMIN]]
+  // CHECK-DAG: %[[VAL27:.+]] = select %[[VAL26]], %[[XYMIN]], %[[VAL21]]
+  // CHECK-DAG: %[[VAL28:.+]] = cmpi slt, %[[XMAX]], %[[VAL21]]
+  // CHECK-DAG: %[[VAL29:.+]] = select %[[VAL28]], %[[XMAX]], %[[VAL27]]
+
+  // Extract the nearest value using the computed indices.
+
+  // CHECK-DAG: %[[IDY:.+]] = index_cast %[[VAL25]]
+  // CHECK-DAG: %[[IDX:.+]] = index_cast %[[VAL29]]
+  // CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%arg1, %[[IDY]], %[[IDX]], %arg4]
+  // CHECK: linalg.yield %[[EXTRACT]]
+  %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "NEAREST_NEIGHBOR" } : (tensor<1x2x2x1xi32>)  -> (tensor<1x4x4x1xi32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @resize_bilinear_int
+func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]
+  // CHECK: %[[GENERIC:.+]] = linalg.indexed_generic
+
+  // CHECK: %[[XYMIN:.+]] = constant 0
+  // CHECK: %[[YMAX:.+]] = constant 1
+  // CHECK: %[[XMAX:.+]] = constant 1
+
+  // CHECK: %[[Y0:.+]] = shift_right_signed
+  // CHECK: %[[X0:.+]] = shift_right_signed
+  // CHECK: %[[ROUNDY:.+]] = shift_left %[[Y0]]
+  // CHECK: %[[ROUNDX:.+]] = shift_left %[[X0]]
+  // CHECK: %[[DY:.+]] = subi %{{.+}}, %[[ROUNDY]]
+  // CHECK: %[[DX:.+]] = subi %{{.+}}, %[[ROUNDX]]
+
+  // Compute the second sample index on each axis: (Y1, X1) = (Y0, X0) + 1.
+
+  // CHECK: %[[ONE:.+]] = constant 1
+  // CHECK: %[[Y1:.+]] = addi %[[Y0]], %[[ONE]]
+  // CHECK: %[[X1:.+]] = addi %[[X0]], %[[ONE]]
+
+  // Bound check each dimension.
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[Y0]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y0]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[YMAX]], %[[Y0]]
+  // CHECK: %[[YLO:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[Y1]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[Y1]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[YMAX]], %[[Y1]]
+  // CHECK: %[[YHI:.+]] = select %[[PRED]], %[[YMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[X0]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X0]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[XMAX]], %[[X0]]
+  // CHECK: %[[XLO:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]]
+
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[X1]], %[[XYMIN]]
+  // CHECK: %[[BOUND:.+]] = select %[[PRED]], %[[XYMIN]], %[[X1]]
+  // CHECK: %[[PRED:.+]] = cmpi slt, %[[XMAX]], %[[X1]]
+  // CHECK: %[[XHI:.+]] = select %[[PRED]], %[[XMAX]], %[[BOUND]]
+
+  // Extract each corner of the bilinear interpolation.
+
+  // CHECK: %[[YLOI:.+]] = index_cast %[[YLO]]
+  // CHECK: %[[YHII:.+]] = index_cast %[[YHI]]
+  // CHECK: %[[XLOI:.+]] = index_cast %[[XLO]]
+  // CHECK: %[[XHII:.+]] = index_cast %[[XHI]]
+
+  // CHECK: %[[LOLO:.+]] = tensor.extract %arg0[%arg1, %[[YLOI]], %[[XLOI]], %arg4]
+  // CHECK: %[[LOHI:.+]] = tensor.extract %arg0[%arg1, %[[YLOI]], %[[XHII]], %arg4]
+  // CHECK: %[[HILO:.+]] = tensor.extract %arg0[%arg1, %[[YHII]], %[[XLOI]], %arg4]
+  // CHECK: %[[HIHI:.+]] = tensor.extract %arg0[%arg1, %[[YHII]], %[[XHII]], %arg4]
+
+  // CHECK: %[[XLOLO:.+]] = sexti %[[LOLO]]
+  // CHECK: %[[XLOHI:.+]] = sexti %[[LOHI]]
+  // CHECK: %[[XHILO:.+]] = sexti %[[HILO]]
+  // CHECK: %[[XHIHI:.+]] = sexti %[[HIHI]]
+    
+  // Interpolate along x within each row, then along y between the two rows.
+
+  // CHECK: %[[SCALE:.+]] = constant 256
+  // CHECK: %[[NDX:.+]] = subi %[[SCALE]], %[[DX]]
+  // CHECK: %[[WLOLO:.+]] = muli %[[XLOLO]], %[[NDX]]
+  // CHECK: %[[WLOHI:.+]] = muli %[[XLOHI]], %[[DX]]
+  // CHECK: %[[LO:.+]] = addi %[[WLOLO]], %[[WLOHI]]
+  // CHECK: %[[WHILO:.+]] = muli %[[XHILO]], %[[NDX]]
+  // CHECK: %[[WHIHI:.+]] = muli %[[XHIHI]], %[[DX]]
+  // CHECK: %[[HI:.+]] = addi %[[WHILO]], %[[WHIHI]]
+  // CHECK: %[[NDY:.+]] = subi %[[SCALE]], %[[DY]]
+  // CHECK: %[[WLO:.+]] = muli %[[LO]], %[[NDY]]
+  // CHECK: %[[WHI:.+]] = muli %[[HI]], %[[DY]]
+  // CHECK: %[[RESULT:.+]] = addi %[[WLO]], %[[WHI]]
+  // CHECK: linalg.yield %[[RESULT]]
+  %output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xi8>)  -> (tensor<1x4x4x1xi32>)
+  return
+}


        


More information about the Mlir-commits mailing list