[Mlir-commits] [mlir] [mlir][tosa] Add FP8 support (PR #127730)

Tai Ly llvmlistbot at llvm.org
Wed Feb 19 10:56:38 PST 2025


================
@@ -459,16 +472,123 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   if (inputETy.isF32() && !accType.isF32())
     return emitOpError("accumulator type for f32 tensor is not f32");
 
+  if ((llvm::isa<Float8E5M2Type>(inputETy) ||
+       llvm::isa<Float8E4M3FNType>(inputETy)) &&
+      !accType.isF16())
+    return emitOpError("accumulator type for f8 tensor is not f16");
+
   if ((inputETy.isF32() && resultETy.isF32()) ||
       (inputETy.isF16() && resultETy.isF16()) ||
       (inputETy.isBF16() && resultETy.isBF16()) ||
+      (llvm::isa<Float8E5M2Type>(inputETy) &&
+       llvm::isa<Float8E5M2Type>(resultETy)) ||
+      (llvm::isa<Float8E4M3FNType>(inputETy) &&
+       llvm::isa<Float8E4M3FNType>(resultETy)) ||
       (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
       (inputETy.isInteger(16) && resultETy.isInteger(16)))
     return success();
 
   return emitOpError("input/output element types are incompatible.");
 }
 
+LogicalResult tosa::CastOp::verify() {
+  mlir::Type inputETy =
+      llvm::cast<ShapedType>(getInput().getType()).getElementType();
+  if (auto inputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(inputETy)) {
+    inputETy = inputQuantType.getStorageType();
+  }
+  mlir::Type outputETy =
+      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+  if (auto outputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(outputETy)) {
+    outputETy = outputQuantType.getStorageType();
+  }
+
+  // input element type: bool
+  if (inputETy.isInteger(1)) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32)) {
+      return success();
+    }
+  }
+  // input element type: int8
+  if (inputETy.isInteger(8)) {
----------------
Tai78641 wrote:

ditto

https://github.com/llvm/llvm-project/pull/127730


More information about the Mlir-commits mailing list