[Mlir-commits] [mlir] 8662a2f - [mlir][tosa] Relax ranked constraint on quantization builder
Rob Suderman
llvmlistbot at llvm.org
Thu Sep 16 11:49:05 PDT 2021
Author: Rob Suderman
Date: 2021-09-16T11:43:47-07:00
New Revision: 8662a2f2081c2a6bf51a490caa045648c88dd230
URL: https://github.com/llvm/llvm-project/commit/8662a2f2081c2a6bf51a490caa045648c88dd230
DIFF: https://github.com/llvm/llvm-project/commit/8662a2f2081c2a6bf51a490caa045648c88dd230.diff
LOG: [mlir][tosa] Relax ranked constraint on quantization builder
TosaOp definition had an artificial constraint that the input/output types
needed to be ranked to invoke the quantization builder. Relaxing this is
correct, as an unranked tensor can still be quantized.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D109863
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4d28d24789d23..483493528b3a2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -350,8 +350,8 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
if (quantAttr) {
result.addAttribute("quantization_info", quantAttr);
- auto inputType = a.getType().dyn_cast<RankedTensorType>();
- assert(inputType && "Input must be a ranked tensor type!");
+ auto inputType = a.getType().dyn_cast<ShapedType>();
+ assert(inputType && "Input must be a shaped tensor type!");
auto inputQType = inputType.getElementType()
.dyn_cast<mlir::quant::UniformQuantizedType>();
@@ -359,17 +359,15 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
- auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
- assert(outputShapedType && "Output must be a ranked tensor type");
-
- auto outputShape = outputShapedType.getShape();
+ auto outputShapedType = outputType.dyn_cast<ShapedType>();
+ assert(outputShapedType && "Output must be a shaped type");
IntegerType accElementType;
if (inputBits == 16)
accElementType = builder.getIntegerType(48);
else
accElementType = builder.getI32Type();
- auto accType = RankedTensorType::get(outputShape, accElementType);
+ auto accType = outputShapedType.clone(accElementType);
result.addTypes(accType);
} else {
result.addTypes(outputType);
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index af3d2be4ec437..6f21e779b37df 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -102,8 +102,8 @@ ConvOpQuantizationAttr
mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
Value weight) {
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto weightType = weight.getType().dyn_cast<RankedTensorType>();
+ auto inputType = input.getType().dyn_cast<ShapedType>();
+ auto weightType = weight.getType().dyn_cast<ShapedType>();
if (!inputType || !weightType)
return nullptr;
@@ -151,8 +151,8 @@ MatMulOpQuantizationAttr
mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
Value b) {
- auto aType = a.getType().dyn_cast<RankedTensorType>();
- auto bType = b.getType().dyn_cast<RankedTensorType>();
+ auto aType = a.getType().dyn_cast<ShapedType>();
+ auto bType = b.getType().dyn_cast<ShapedType>();
if (!aType || !bType)
return nullptr;
@@ -187,8 +187,8 @@ UnaryOpQuantizationAttr
mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
Type outputRawType) {
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto outputType = outputRawType.dyn_cast<RankedTensorType>();
+ auto inputType = input.getType().dyn_cast<ShapedType>();
+ auto outputType = outputRawType.dyn_cast<ShapedType>();
if (!inputType || !outputType)
return nullptr;
@@ -220,7 +220,7 @@ mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
Value input) {
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
+ auto inputType = input.getType().dyn_cast<ShapedType>();
if (!inputType)
return nullptr;
@@ -245,8 +245,8 @@ PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
Value input, Value weight) {
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto weightType = weight.getType().dyn_cast<RankedTensorType>();
+ auto inputType = input.getType().dyn_cast<ShapedType>();
+ auto weightType = weight.getType().dyn_cast<ShapedType>();
assert(inputType && weightType &&
"Could not extract input or weight tensors from Conv op");
@@ -260,18 +260,16 @@ Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
- auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
+ auto outputShapedType = outputType.dyn_cast<ShapedType>();
assert(outputShapedType &&
"Could not extract output shape type from Conv op");
- auto outputShape = outputShapedType.getShape();
-
IntegerType accElementType;
if (inputBits == 16 && weightBits == 8)
accElementType = builder.getIntegerType(48);
else
accElementType = builder.getI32Type();
- auto accType = RankedTensorType::get(outputShape, accElementType);
+ auto accType = outputShapedType.clone(accElementType);
return accType;
}
More information about the Mlir-commits
mailing list