[Mlir-commits] [mlir] e9fae0f - [mlir][tosa] Disable tosa.depthwise_conv2d canonicalizer for quantized case

Rob Suderman llvmlistbot at llvm.org
Tue Dec 7 10:17:29 PST 2021


Author: Rob Suderman
Date: 2021-12-07T10:16:12-08:00
New Revision: e9fae0f19eec1fce746101b410d2345f0fbf89b4

URL: https://github.com/llvm/llvm-project/commit/e9fae0f19eec1fce746101b410d2345f0fbf89b4
DIFF: https://github.com/llvm/llvm-project/commit/e9fae0f19eec1fce746101b410d2345f0fbf89b4.diff

LOG: [mlir][tosa] Disable tosa.depthwise_conv2d canonicalizer for quantized case

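The quantized case would need zero-point corrections applied to the input and
weight before the tosa.mul, which this rewrite does not perform. The
canonicalization is therefore disabled when quantization info is present or the
input element type is not floating point.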

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D115264

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index cefe13f57dbba..601e66006d6ff 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -526,12 +526,18 @@ struct DepthwiseConv2DMulOptimization
     ShapedType inputType = input.getType().cast<ShapedType>();
     ShapedType weightType = weight.getType().cast<ShapedType>();
     ShapedType resultType = op.output().getType().cast<ShapedType>();
+    Type inputEType = inputType.getElementType();
 
     if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
           resultType.hasStaticShape())) {
       return failure();
     }
 
+    // Quantization information needs to still be performed.
+    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
+      return failure();
+    }
+
     // Stride must be 1 for this optimization.
     for (Attribute stride : op.stride().getValue()) {
       if (!stride.cast<IntegerAttr>().getValue().isOne()) {

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index ed659ee91964d..a9418be3e632f 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -128,6 +128,15 @@ func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x
 
 // -----
 
+// CHECK-LABEL: @depthwise_conv2d_as_mul_q
+func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
+  // CHECK: "tosa.depthwise_conv2d"
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
+  return %0 : tensor<4x10x10x6xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @depthwise_conv2d_stride_2
 func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
   // CHECK: "tosa.depthwise_conv2d"


        


More information about the Mlir-commits mailing list