[Mlir-commits] [mlir] [mlir][tosa] Make Convolution Zero Points Inputs (PR #122939)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 24 12:26:16 PST 2025
================
@@ -220,30 +220,57 @@ static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input() and weight().
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
- RankedTensorType weightType;
- if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
- weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
- else
- weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+ auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
// Must be ranked tensor types
if (!inputType) {
op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
return failure();
}
if (!weightType) {
- if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
- op.emitOpError("expect a ranked tensor for filter, got ")
- << op.getFilter();
- } else {
- op.emitOpError("expect a ranked tensor for weight, got ")
- << op.getWeight();
- }
+ op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
return failure();
}
auto inputEType = inputType.getElementType();
auto weightEType = weightType.getElementType();
+ auto biasEType =
+ llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
+ auto resultEType =
+ llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+ bool biasIsFloat = llvm::isa<FloatType>(biasEType);
+ bool resultIsFloat = llvm::isa<FloatType>(resultEType);
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+ inputEType = quantType.getStorageType();
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+ biasEType = quantType.getStorageType();
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+ resultEType = quantType.getStorageType();
+
+ if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
+ // for now, only enforce bias element type == result element type for
+ // float types.
+ op.emitOpError(
+ "expect both bias and result to have same element type, got ")
+ << biasEType << " and " << resultEType;
+ return failure();
+ }
+
+ if (inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3FN() ||
----------------
Jerry-Ge wrote:
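`inputEType.isFloat8E5M2()`: this API is outdated and no longer compiles (the `Type::isFloat8*()` helpers were replaced by type classes, so this check should be written as e.g. `llvm::isa<Float8E5M2Type>(inputEType)`). See https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28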
https://github.com/llvm/llvm-project/pull/122939
More information about the Mlir-commits mailing list