[Mlir-commits] [mlir] [mlir][tosa] Add FP8 support (PR #127730)
Tai Ly
llvmlistbot at llvm.org
Wed Feb 19 10:56:38 PST 2025
================
@@ -459,16 +472,123 @@ LogicalResult tosa::AvgPool2dOp::verify() {
if (inputETy.isF32() && !accType.isF32())
return emitOpError("accumulator type for f32 tensor is not f32");
+ if ((llvm::isa<Float8E5M2Type>(inputETy) ||
+ llvm::isa<Float8E4M3FNType>(inputETy)) &&
+ !accType.isF16())
+ return emitOpError("accumulator type for f8 tensor is not f16");
+
if ((inputETy.isF32() && resultETy.isF32()) ||
(inputETy.isF16() && resultETy.isF16()) ||
(inputETy.isBF16() && resultETy.isBF16()) ||
+ (llvm::isa<Float8E5M2Type>(inputETy) &&
+ llvm::isa<Float8E5M2Type>(resultETy)) ||
+ (llvm::isa<Float8E4M3FNType>(inputETy) &&
+ llvm::isa<Float8E4M3FNType>(resultETy)) ||
(inputETy.isInteger(8) && resultETy.isInteger(8)) ||
(inputETy.isInteger(16) && resultETy.isInteger(16)))
return success();
return emitOpError("input/output element types are incompatible.");
}
+LogicalResult tosa::CastOp::verify() {
+ mlir::Type inputETy =
+ llvm::cast<ShapedType>(getInput().getType()).getElementType();
+ if (auto inputQuantType =
+ llvm::dyn_cast<mlir::quant::QuantizedType>(inputETy)) {
+ inputETy = inputQuantType.getStorageType();
+ }
+ mlir::Type outputETy =
+ llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+ if (auto outputQuantType =
+ llvm::dyn_cast<mlir::quant::QuantizedType>(outputETy)) {
+ outputETy = outputQuantType.getStorageType();
+ }
+
+ // input element type: bool
+ if (inputETy.isInteger(1)) {
+ if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32)) {
+ return success();
+ }
+ }
+ // input element type: int8
+ if (inputETy.isInteger(8)) {
----------------
Tai78641 wrote:
ditto
https://github.com/llvm/llvm-project/pull/127730
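For context, a minimal sketch (not taken from the PR) of how the two FP8 element types checked in the diff could be matched with a single variadic llvm::isa call; the helper name isFloat8Type is hypothetical:

    // Minimal sketch, assuming the MLIR builtin float-8 types used in the PR;
    // the helper name isFloat8Type is hypothetical, not from the patch.
    #include "mlir/IR/BuiltinTypes.h"
    #include "llvm/Support/Casting.h"

    // Returns true if the element type is one of the FP8 formats the new
    // checks accept (f8E5M2 or f8E4M3FN). The variadic form of llvm::isa
    // folds the two separate isa<> calls from the diff into one.
    static bool isFloat8Type(mlir::Type elementTy) {
      return llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(elementTy);
    }

With a helper like this, the accumulator check in AvgPool2dOp::verify and the matching input/result cases could each reduce to a single call, which keeps the condition lists readable if further formats are added later.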