[Mlir-commits] [mlir] 46c96fc - [mlir][tosa] Fix quantized type for tosa.conv2d canonicalization

Rob Suderman llvmlistbot at llvm.org
Thu Dec 9 12:46:40 PST 2021


Author: Rob Suderman
Date: 2021-12-09T12:39:23-08:00
New Revision: 46c96fca0e793a71198f54ab2196dd2af030de60

URL: https://github.com/llvm/llvm-project/commit/46c96fca0e793a71198f54ab2196dd2af030de60
DIFF: https://github.com/llvm/llvm-project/commit/46c96fca0e793a71198f54ab2196dd2af030de60.diff

LOG: [mlir][tosa] Fix quantized type for tosa.conv2d canonicalization

The wrong element type was used for the result in the tosa.conv2d canonicalization.
The element type of the generated fully connected result should match the element
type of the conv2d result, not the element type of the weight input.

Differential Revision: https://reviews.llvm.org/D115463

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 898bc258469f6..579f0b407ee8e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -482,7 +482,7 @@ struct Conv2DFullyConnectedOptimization
         inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
     auto fullyConnectedShapeType = RankedTensorType::get(
         fullyConnectedShape,
-        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+        resultType.dyn_cast<ShapedType>().getElementType());
 
     Value fullyConnectedValue;
     if (op.quantization_info()) {

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index a9418be3e632f..7436bfa8ba9ff 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -86,6 +86,25 @@ func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x
 
 // -----
 
+// CHECK-LABEL: @conv2d_as_fully_connected_quant
+func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
+  // CHECK-NOT: "tosa.conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
+  // CHECK-SAME: -> tensor<400x2xi8>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+  // CHECK-SAME: -> tensor<3x2xi8>
+  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
+  // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
+  // CHECK-SAME: -> tensor<400x3xi32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
+  // CHECK-SAME: -> tensor<4x10x10x3xi32>
+  // CHECK: return %[[VAR3]]
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
+  return %0 : tensor<4x10x10x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_stride_2
 func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
   // CHECK: "tosa.conv2d"


        


More information about the Mlir-commits mailing list