[Mlir-commits] [mlir] [mlir][tosa] Add more verifiers for the following operators (PR #127923)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 27 11:25:20 PST 2025
================
@@ -472,6 +472,104 @@ LogicalResult tosa::AvgPool2dOp::verify() {
return emitOpError("input/output element types are incompatible.");
}
+LogicalResult tosa::CastOp::verify() {
+ mlir::Type inputETy =
+ llvm::cast<ShapedType>(getInput().getType()).getElementType();
+ if (auto inputQuantType =
+ llvm::dyn_cast<mlir::quant::QuantizedType>(inputETy)) {
+ inputETy = inputQuantType.getStorageType();
+ }
+ mlir::Type outputETy =
+ llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+ if (auto outputQuantType =
+ llvm::dyn_cast<mlir::quant::QuantizedType>(outputETy)) {
+ outputETy = outputQuantType.getStorageType();
+ }
+
+ // input element type: bool
+ if (inputETy.isInteger(1)) {
+ if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32)) {
+ return success();
+ }
+ }
+ // input element type: int8
+ if (inputETy.isInteger(8)) {
+ if (outputETy.isInteger(1) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+ outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: int16
+ if (inputETy.isInteger(16)) {
+ if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+ outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+ outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: int32
+ if (inputETy.isInteger(32)) {
+ if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+ outputETy.isInteger(16) || outputETy.isF16() || outputETy.isBF16() ||
+ outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: bf16 or fp16
+ if (inputETy.isBF16() || inputETy.isF16()) {
+ if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+ llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: f8e4m3 or f8e5m2
+ if (llvm::isa<Float8E4M3FNType>(inputETy) ||
+ llvm::isa<Float8E5M2Type>(inputETy)) {
+ if (outputETy.isF16() || outputETy.isBF16() || outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: fp32
+ if (inputETy.isF32()) {
+ if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+ llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF16() ||
+ outputETy.isBF16()) {
+ return success();
+ }
+ }
+
+ // following are outside of TOSA Spec
----------------
Jerry-Ge wrote:
I think those should be moved to the validation pass.
https://github.com/llvm/llvm-project/pull/127923
More information about the Mlir-commits mailing list