[Mlir-commits] [mlir] [mlir][tosa] Make Convolution Zero Points Inputs (PR #122939)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 24 12:27:00 PST 2025


================
@@ -220,30 +220,57 @@ static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
 
-  RankedTensorType weightType;
-  if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
-    weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
-  else
-    weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
 
   // Must be ranked tensor types
   if (!inputType) {
     op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
     return failure();
   }
   if (!weightType) {
-    if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
-      op.emitOpError("expect a ranked tensor for filter, got ")
-          << op.getFilter();
-    } else {
-      op.emitOpError("expect a ranked tensor for weight, got ")
-          << op.getWeight();
-    }
+    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
     return failure();
   }
 
   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
+  auto biasEType =
+      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
+  auto resultEType =
+      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
+  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+    inputEType = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+    biasEType = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+    resultEType = quantType.getStorageType();
+
+  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
+    // for now, only enforce bias element type == result element type for
+    // float types.
+    op.emitOpError(
+        "expect both bias and result to have same element type, got ")
+        << biasEType << " and " << resultEType;
+    return failure();
+  }
+
+  if (inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3FN() ||
----------------
Jerry-Ge wrote:

I have another FP8 patch waiting to be upstreamed. It would be better to rebase this patch on top of that FP8 patch.

https://github.com/llvm/llvm-project/pull/122939


More information about the Mlir-commits mailing list