[Mlir-commits] [mlir] 95e4b71 - [mlir][tosa] Fix tosa average_pool2d to linalg type issue

Rob Suderman llvmlistbot at llvm.org
Tue Oct 12 13:11:22 PDT 2021


Author: Rob Suderman
Date: 2021-10-12T13:09:21-07:00
New Revision: 95e4b71519e6621a132252b462b9bf9fce63ff61

URL: https://github.com/llvm/llvm-project/commit/95e4b71519e6621a132252b462b9bf9fce63ff61
DIFF: https://github.com/llvm/llvm-project/commit/95e4b71519e6621a132252b462b9bf9fce63ff61.diff

LOG: [mlir][tosa] Fix tosa average_pool2d to linalg type issue

The average pool lowering assumed that the input and output element types
were the same. For integer inputs the result type is always i32, so the
lowering now uses the op's actual result element type.

Reviewed By: GMNGeoffrey

Differential Revision: https://reviews.llvm.org/D111590

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Tosa/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index cc8c8fab56e5b..b57e8b2fb8cbb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -82,6 +82,8 @@ def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
   );
 
   let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
+
+  let verifier = [{ return verifyAveragePoolOp(*this); }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f24a849810e2c..bd1769d1d1baf 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2796,7 +2796,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
     Type inElementTy = inputTy.getElementType();
 
     ShapedType resultTy = op.getType().template cast<ShapedType>();
-    Type resultETy = inputTy.getElementType();
+    Type resultETy = op.getType().cast<ShapedType>().getElementType();
 
     Type accETy =
         inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
@@ -2810,9 +2810,10 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
     pad.resize(2, 0);
     getValuesFromIntArrayAttribute(op.pad(), pad);
     pad.resize(pad.size() + 2, 0);
-    Attribute initialAttr = rewriter.getZeroAttr(accETy);
-    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
+    Attribute padAttr = rewriter.getZeroAttr(inElementTy);
+    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
 
+    Attribute initialAttr = rewriter.getZeroAttr(accETy);
     Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
 
     SmallVector<int64_t> kernel, stride;
@@ -2909,8 +2910,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
           // to be applied.
           Value poolVal = args[0];
           if (accETy.isa<FloatType>()) {
-            auto countF =
-                rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
+            auto countF = rewriter.create<mlir::SIToFPOp>(loc, accETy, countI);
             poolVal =
                 rewriter.create<DivFOp>(loc, poolVal, countF)->getResult(0);
           } else {
@@ -2974,8 +2974,11 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
             auto clamp = clampHelper<mlir::CmpIOp>(
                 loc, scaled, min, max, CmpIPredicate::slt, rewriter);
 
+            poolVal = clamp;
             // Convert type.
-            poolVal = rewriter.create<TruncateIOp>(loc, resultETy, clamp);
+            if (resultETy != clamp.getType()) {
+              poolVal = rewriter.create<TruncateIOp>(loc, resultETy, poolVal);
+            }
           }
 
           // Cast to output type.

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3a025437bc05b..9c8d4ac54a848 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -342,6 +342,26 @@ static LogicalResult verifyConvOp(T op) {
   return success();
 }
 
+static LogicalResult verifyAveragePoolOp(tosa::AvgPool2dOp op) {
+  auto inputETy = op.input().getType().cast<ShapedType>().getElementType();
+  auto resultETy = op.getType().cast<ShapedType>().getElementType();
+
+  if (auto quantType = inputETy.dyn_cast<mlir::quant::UniformQuantizedType>())
+    inputETy = quantType.getStorageType();
+
+  if (auto quantType = resultETy.dyn_cast<mlir::quant::UniformQuantizedType>())
+    resultETy = quantType.getStorageType();
+
+  if (inputETy.isF32() && resultETy.isF32())
+    return success();
+  if (inputETy.isInteger(8) && resultETy.isInteger(32))
+    return success();
+  if (inputETy.isInteger(16) && resultETy.isInteger(32))
+    return success();
+
+  return op.emitOpError("input/output element types are incompatible.");
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Quantization Builders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1c81a2aa731cd..df6677281baa4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1465,15 +1465,14 @@ func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () {
   // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
   // CHECK: %[[OUTZP:.+]] = constant -128
   // CHECK: %[[OUT:.+]] = addi %[[SCALE]], %[[OUTZP]]
-  // CHECK: %[[MIN:.+]] = constant -128
-  // CHECK: %[[MAX:.+]] = constant 127
+  // CHECK: %[[MIN:.+]] = constant -2147483648
+  // CHECK: %[[MAX:.+]] = constant 2147483647
   // CHECK: %[[CMP_MIN:.+]] = cmpi slt, %[[OUT]], %[[MIN]]
   // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
   // CHECK: %[[CMP_MAX:.+]] = cmpi slt, %[[MAX]], %[[OUT]]
   // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
-  // CHECK: %[[TRUNC:.+]] = trunci %[[CLMP_MAX]]
-  // CHECK: linalg.yield %[[TRUNC]]
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8>
+  // CHECK: linalg.yield %[[CLMP_MAX]]
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi32>
   return
 }
 

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index ec169d0e16ebf..df79ebb89c3e7 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -10,12 +10,33 @@ func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
 }
 
 // -----
-// CHECK-LABEL: avg_pool2d
-func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
+// CHECK-LABEL: avg_pool2d_f32
+func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
     %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
 }
 
+// -----
+// CHECK-LABEL: avg_pool2d_i8
+func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi32> {
+    %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi32>
+    return %0 : tensor<1x7x7x9xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_i16
+func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi32> {
+    %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi32>
+    return %0 : tensor<1x7x7x9xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_q8
+func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i32:f32, 0.01>> {
+    %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i32:f32, 0.01>>
+    return %0 : tensor<1x7x7x9x!quant.uniform<i32:f32, 0.01>>
+}
+
 // -----
 // CHECK-LABEL: conv2d
 func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {


        


More information about the Mlir-commits mailing list