[Mlir-commits] [mlir] 1a8d2a4 - [mlir][tosa] Use generic QuantizedType in Conv verifiers (#126275)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 11 02:58:19 PST 2025


Author: Tai Ly
Date: 2025-02-11T10:58:16Z
New Revision: 1a8d2a40167ea752366fa718a98cafdccce002a5

URL: https://github.com/llvm/llvm-project/commit/1a8d2a40167ea752366fa718a98cafdccce002a5
DIFF: https://github.com/llvm/llvm-project/commit/1a8d2a40167ea752366fa718a98cafdccce002a5.diff

LOG: [mlir][tosa] Use generic QuantizedType in Conv verifiers (#126275)

Replace UniformQuantizedType with the more generic QuantizedType in Conv verifiers.

Change-Id: Ie1961af931864f801914a62976bc988881ee075e

Signed-off-by: Tai Ly <tai.ly at arm.com>
Co-authored-by: Thibaut Goetghebuer-Planchon <thibaut.goetghebuer-planchon at arm.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fd166cc1322cef9..af4a5dc96265ed5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
   bool biasIsFloat = llvm::isa<FloatType>(biasEType);
   bool resultIsFloat = llvm::isa<FloatType>(resultEType);
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
     inputEType = quantType.getStorageType();
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
     biasEType = quantType.getStorageType();
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
     resultEType = quantType.getStorageType();
 
   if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
   auto inputEType =
       llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
 
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
     inputEType = quantType.getStorageType();
 
   auto accType = op.getAccType();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
   if (inputEType.isF32() && !accType.isF32())
     return op.emitOpError("accumulator type for f32 tensor is not f32");
 
-  return success();
+  auto resultEType =
+      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
+    resultEType = quantType.getStorageType();
+
+  // check allowed input/result element types combinations
+  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
+      (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
+      (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
+      (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
+      (inputEType.isF16() && resultEType.isF16()) ||
+      (inputEType.isBF16() && resultEType.isBF16()) ||
+      (inputEType.isF32() && resultEType.isF32()))
+    return success();
+
+  return op.emitOpError("input/output element types are incompatible.");
 }
 
 // verify that inType and outType have same element types

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 2165e1f7ae3babb..20fc10d77d0e0c4 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -144,6 +144,24 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1
   return %0 : tensor<1x32x32x16xi8>
 }
 
+// -----
+// CHECK-LABEL: conv2d_quant_any_acc
+func.func @test_conv2d_quant_any_acc(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i8<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>> {
+  %zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i8<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>>
+  return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
+}
+
+// -----
+// CHECK-LABEL: conv2d_quant_any_result
+func.func @test_conv2d_quant_any_result(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i8<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>> {
+  %zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.conv2d' op input/output element types are incompatible}}
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i8<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i8<-8:7>>>
+  return %0 : tensor<1x4x4x8x!quant.any<i8<-8:7>>>
+}
+
 // -----
 
 func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index baf09e089aa30fb..d7e4f682c28b3cc 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -58,6 +58,22 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
   return %0 : tensor<1x4x4x8xf32>
 }
 
+// -----
+// CHECK-LABEL: conv2d_quant_uniform
+func.func @test_conv2d_quant_uniform(%arg0: tensor<1x4x4x4x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<8x1x1x4x!quant.uniform<i8:f32, 0.01>>, %arg2: tensor<8x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>> {
+  %zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.uniform<i8:f32, 0.01>>, tensor<8x1x1x4x!quant.uniform<i8:f32, 0.01>>, tensor<8x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>>
+  return %0 : tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>>
+}
+
+// -----
+// CHECK-LABEL: conv2d_quant_any
+func.func @test_conv2d_quant_any(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7>>>, %arg1: tensor<8x1x1x4x!quant.any<i8<-8:7>>>, %arg2: tensor<8x!quant.any<i8<-8:7>>>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>> {
+  %zp = "tosa.const" () { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4x!quant.any<i8<-8:7>>>, tensor<8x1x1x4x!quant.any<i8<-8:7>>>, tensor<8x!quant.any<i8<-8:7>>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8x!quant.any<i32<-8:7>>>
+  return %0 : tensor<1x4x4x8x!quant.any<i32<-8:7>>>
+}
+
 // -----
 // CHECK-LABEL: conv2d_q8xi4
 func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> {


        


More information about the Mlir-commits mailing list