[Mlir-commits] [mlir] 8662a2f - [mlir][tosa] Relax ranked constraint on quantization builder

Rob Suderman llvmlistbot at llvm.org
Thu Sep 16 11:49:05 PDT 2021


Author: Rob Suderman
Date: 2021-09-16T11:43:47-07:00
New Revision: 8662a2f2081c2a6bf51a490caa045648c88dd230

URL: https://github.com/llvm/llvm-project/commit/8662a2f2081c2a6bf51a490caa045648c88dd230
DIFF: https://github.com/llvm/llvm-project/commit/8662a2f2081c2a6bf51a490caa045648c88dd230.diff

LOG: [mlir][tosa] Relax ranked constraint on quantization builder

The TosaOp definition had an artificial constraint that the input/output types
needed to be ranked to invoke the quantization builder. This constraint was
incorrect, as an unranked tensor can still be quantized.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D109863

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4d28d24789d23..483493528b3a2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -350,8 +350,8 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
   if (quantAttr) {
     result.addAttribute("quantization_info", quantAttr);
 
-    auto inputType = a.getType().dyn_cast<RankedTensorType>();
-    assert(inputType && "Input must be a ranked tensor type!");
+    auto inputType = a.getType().dyn_cast<ShapedType>();
+    assert(inputType && "Input must be a shaped tensor type!");
 
     auto inputQType = inputType.getElementType()
                           .dyn_cast<mlir::quant::UniformQuantizedType>();
@@ -359,17 +359,15 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
 
     unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
 
-    auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
-    assert(outputShapedType && "Output must be a ranked tensor type");
-
-    auto outputShape = outputShapedType.getShape();
+    auto outputShapedType = outputType.dyn_cast<ShapedType>();
+    assert(outputShapedType && "Output must be a shaped type");
 
     IntegerType accElementType;
     if (inputBits == 16)
       accElementType = builder.getIntegerType(48);
     else
       accElementType = builder.getI32Type();
-    auto accType = RankedTensorType::get(outputShape, accElementType);
+    auto accType = outputShapedType.clone(accElementType);
     result.addTypes(accType);
   } else {
     result.addTypes(outputType);

diff  --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index af3d2be4ec437..6f21e779b37df 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -102,8 +102,8 @@ ConvOpQuantizationAttr
 mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
                                         Value weight) {
 
-  auto inputType = input.getType().dyn_cast<RankedTensorType>();
-  auto weightType = weight.getType().dyn_cast<RankedTensorType>();
+  auto inputType = input.getType().dyn_cast<ShapedType>();
+  auto weightType = weight.getType().dyn_cast<ShapedType>();
 
   if (!inputType || !weightType)
     return nullptr;
@@ -151,8 +151,8 @@ MatMulOpQuantizationAttr
 mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
                                           Value b) {
 
-  auto aType = a.getType().dyn_cast<RankedTensorType>();
-  auto bType = b.getType().dyn_cast<RankedTensorType>();
+  auto aType = a.getType().dyn_cast<ShapedType>();
+  auto bType = b.getType().dyn_cast<ShapedType>();
 
   if (!aType || !bType)
     return nullptr;
@@ -187,8 +187,8 @@ UnaryOpQuantizationAttr
 mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
                                          Type outputRawType) {
 
-  auto inputType = input.getType().dyn_cast<RankedTensorType>();
-  auto outputType = outputRawType.dyn_cast<RankedTensorType>();
+  auto inputType = input.getType().dyn_cast<ShapedType>();
+  auto outputType = outputRawType.dyn_cast<ShapedType>();
 
   if (!inputType || !outputType)
     return nullptr;
@@ -220,7 +220,7 @@ mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
 PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
                                                              Value input) {
 
-  auto inputType = input.getType().dyn_cast<RankedTensorType>();
+  auto inputType = input.getType().dyn_cast<ShapedType>();
 
   if (!inputType)
     return nullptr;
@@ -245,8 +245,8 @@ PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
 Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
                                            Value input, Value weight) {
 
-  auto inputType = input.getType().dyn_cast<RankedTensorType>();
-  auto weightType = weight.getType().dyn_cast<RankedTensorType>();
+  auto inputType = input.getType().dyn_cast<ShapedType>();
+  auto weightType = weight.getType().dyn_cast<ShapedType>();
 
   assert(inputType && weightType &&
          "Could not extract input or weight tensors from Conv op");
@@ -260,18 +260,16 @@ Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
   unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
   unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
 
-  auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
+  auto outputShapedType = outputType.dyn_cast<ShapedType>();
   assert(outputShapedType &&
          "Could not extract output shape type from Conv op");
 
-  auto outputShape = outputShapedType.getShape();
-
   IntegerType accElementType;
   if (inputBits == 16 && weightBits == 8)
     accElementType = builder.getIntegerType(48);
   else
     accElementType = builder.getI32Type();
-  auto accType = RankedTensorType::get(outputShape, accElementType);
+  auto accType = outputShapedType.clone(accElementType);
   return accType;
 }
 


        


More information about the Mlir-commits mailing list