[Mlir-commits] [mlir] b37a031 - [mlir][tosa] Make tosa.resize to linalg avoid redundant loads for unit width

Rob Suderman llvmlistbot at llvm.org
Thu Dec 15 16:23:50 PST 2022


Author: Rob Suderman
Date: 2022-12-15T16:22:46-08:00
New Revision: b37a0318cb026fe30a76482e36f66f8a3e61e055

URL: https://github.com/llvm/llvm-project/commit/b37a0318cb026fe30a76482e36f66f8a3e61e055
DIFF: https://github.com/llvm/llvm-project/commit/b37a0318cb026fe30a76482e36f66f8a3e61e055.diff

LOG: [mlir][tosa] Make tosa.resize to linalg avoid redundant loads for unit width

When lowering a tosa.resize from ?x1x1x? to ?x1x?x?, we should avoid doing a 2D
interpolation, as only two unique values are loaded. Because numerical
computation is performed on the extracted values, the superfluous extracts may
fail to be coalesced. Instead, we only interpolate between the values if there
are multiple values to interpolate between.

For the integer case, we still multiply the non-interpolated value by the scale
factor, so the result carries the same integer scaling that interpolation would apply.

Reviewed By: jpienaar, NatashaKnk

Differential Revision: https://reviews.llvm.org/D139979

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d704b5e040916..ebc63bdff3dca 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1468,7 +1468,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
       Value x = b.create<linalg::IndexOp>(2);
       Value channel = b.create<linalg::IndexOp>(3);
 
-      Value zeroI32 = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
+      Value zeroI32 =
+          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
+      Value zeroFp32 =
+          b.create<arith::ConstantOp>(b.getZeroAttr(b.getF32Type()));
       Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
       Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
 
@@ -1498,6 +1501,11 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
       auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
+        if (size == 1) {
+          index = zeroI32;
+          delta = zeroFp32;
+          return;
+        }
         // x = x * scale_d + offset;
         // ix = floor(x / scale_n)
         // dx = x / scale_n - ix
@@ -1517,6 +1525,11 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
       auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                      Value scaleN, Value scaleD, Value offset,
                                      int size, ImplicitLocOpBuilder &b) {
+        if (size == 1) {
+          index = zeroI32;
+          delta = zeroI32;
+          return;
+        }
         // x = x * scale_d + offset;
         // ix = floor(x / scale_n)
         //  dx = x - ix * scale_n;
@@ -1606,7 +1619,10 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
         if (floatingPointMode) {
           auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f));
           auto interpolate = [&](Value val0, Value val1, Value delta,
+                                 int inputSize,
                                  ImplicitLocOpBuilder &b) -> Value {
+            if (inputSize == 1)
+              return val0;
             Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
             Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
             Value mul1 = b.create<arith::MulFOp>(val1, delta);
@@ -1616,16 +1632,16 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
           // Linalg equivalent to the section below:
           //   topAcc = v00 * (unit_x - dx);
           //   topAcc += v01 * dx;
-          Value topAcc = interpolate(y0x0, y0x1, dx, b);
+          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
 
           // Linalg equivalent to the section below:
           //   bottomAcc = v10 * (unit_x - dx);
           //   bottomAcc += v11 * dx;
-          Value bottomAcc = interpolate(y1x0, y1x1, dx, b);
+          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
 
           // Linalg equivalent to the section below:
           //   result = topAcc * (unit_y - dy) + bottomAcc * dy
-          Value result = interpolate(topAcc, bottomAcc, dy, b);
+          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
           b.create<linalg::YieldOp>(result);
         } else {
           // Perform in quantized space.
@@ -1650,22 +1666,21 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
             xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
           }
 
-          auto interpolate = [](Value val0, Value val1, Value weight0,
-                                Value weight1,
+          auto interpolate = [](Value val0, Value val1, Value weight1,
+                                Value scale, int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
+            if (inputSize == 1)
+              return b.create<arith::MulIOp>(val0, scale);
+            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
             Value mul0 = b.create<arith::MulIOp>(val0, weight0);
             Value mul1 = b.create<arith::MulIOp>(val1, weight1);
             return b.create<arith::AddIOp>(mul0, mul1);
           };
 
-          Value weight0 = b.create<arith::SubIOp>(xScaleNExt, dx);
-          Value weight1 = dx;
-          Value topAcc = interpolate(y0x0, y0x1, weight0, weight1, b);
-          Value bottomAcc = interpolate(y1x0, y1x1, weight0, weight1, b);
-
-          weight0 = b.create<arith::SubIOp>(yScaleNExt, dy);
-          weight1 = dy;
-          Value result = interpolate(topAcc, bottomAcc, weight0, weight1, b);
+          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
+          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
+          Value result =
+              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
           b.create<linalg::YieldOp>(result);
         }
       }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
index 3aa6d2aac7623..382dae542ea9a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -278,6 +278,7 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) {
   // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]]
   // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X_EXT]]
   // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]]
+  // CHECK: %[[NDX:.+]] = arith.subi %[[X_N_EXT]], %[[D_X_EXT]]
   // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]]
   // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X_EXT]]
   // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]]
@@ -492,3 +493,47 @@ func.func @resize_bilinear_int48(%arg0: tensor<1x19x19x1xi16>) {
   %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x19x1xi16>) -> tensor<1x289x289x1xi48>
            return
 }
+
+// -----
+
+// CHECK-LABEL: skip_interpolate_bilinear_i8
+func.func @skip_interpolate_bilinear_i8(%arg0 : tensor<3x1x2x7xi8>) -> tensor<3x1x5x7xi32> {
+  // CHECK:  %[[GENERIC:.+]] = linalg.generic
+  // CHECK:    %[[BATCH:.+]] = linalg.index 0
+  // CHECK:    %[[CHANNEL:.+]] = linalg.index 3
+  // CHECK-DAG:    %[[C3:.+]] = arith.constant 3
+  // CHECK-DAG:    %[[C2:.+]] = arith.constant 2
+  // CHECK:    %[[EXTRACT0:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xi8>
+  // CHECK:    %[[EXTRACT1:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xi8>
+  // CHECK:    %[[EXT0:.+]] = arith.extsi %[[EXTRACT0]] : i8 to i32
+  // CHECK:    %[[EXT1:.+]] = arith.extsi %[[EXTRACT1]] : i8 to i32
+  // CHECK:    %[[SUB:.+]] = arith.subi %[[C3]], %[[DX:.+]]
+  // CHECK:    %[[MUL0:.+]] = arith.muli %[[EXT0]], %[[SUB]]
+  // CHECK:    %[[MUL1:.+]] = arith.muli %[[EXT1]], %[[DX]]
+  // CHECK:    %[[ADD:.+]] = arith.addi %[[MUL0]], %[[MUL1]]
+  // CHECK:    %[[RES:.+]] = arith.muli %[[ADD]], %[[C2]]
+  // CHECK:    linalg.yield %[[RES]]
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 1, 3, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x2x7xi8>) -> tensor<3x1x5x7xi32>
+
+  // CHECK:  return %[[GENERIC]]
+  return %resize : tensor<3x1x5x7xi32>
+}
+
+// CHECK-LABEL: skip_interpolate_bilinear_f32
+func.func @skip_interpolate_bilinear_f32(%arg0 : tensor<3x1x2x7xf32>) -> tensor<3x1x5x7xf32> {
+  // CHECK:  %[[GENERIC:.+]] = linalg.generic
+  // CHECK:    %[[BATCH:.+]] = linalg.index 0 : index
+  // CHECK:    %[[CHANNEL:.+]] = linalg.index 3 : index
+  // CHECK:    %[[EXTRACT0:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xf32>
+  // CHECK:    %[[EXTRACT1:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xf32>
+  // CHECK:    %[[C1:.+]] = arith.constant 1.000000e+00
+  // CHECK:    %[[SUB:.+]] = arith.subf %[[C1]], %[[DX:.+]]
+  // CHECK:    %[[MUL0:.+]] = arith.mulf %[[EXTRACT0]], %[[SUB]]
+  // CHECK:    %[[MUL1:.+]] = arith.mulf %[[EXTRACT1]], %[[DX]]
+  // CHECK:    %[[ADD:.+]] = arith.addf %[[MUL0]], %[[MUL1]]
+  // CHECK:    linalg.yield %[[ADD]]
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 1, 3, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x2x7xf32>) -> tensor<3x1x5x7xf32>
+
+  // CHECK:  return %[[GENERIC]]
+  return %resize : tensor<3x1x5x7xf32>
+}


        


More information about the Mlir-commits mailing list