[Mlir-commits] [mlir] Support float8_e3m4 and float8_e4m3 in np_to_memref (PR #186453)
llvmlistbot at llvm.org
Fri Mar 13 10:18:01 PDT 2026
================
@@ -77,6 +94,16 @@ def to_numpy(array):
     ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
     if array.dtype == F8E5M2:
         return array.view("float8_e5m2")
+    assert not (
+        array.dtype == F8E3M4 and ml_dtypes is None
+    ), f"float8_e3m4 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+    if array.dtype == F8E3M4:
+        return array.view("float8_e3m4")
+    assert not (
+        array.dtype == F8E4M3 and ml_dtypes is None
+    ), f"float8_e4m3 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+    if array.dtype == F8E4M3:
+        return array.view("float8_e4m3")
----------------
srcarroll wrote:
I was following precedent with BF16 and F8E5M2 here as well, but yes, agreed that it should be refactored. Will do.
https://github.com/llvm/llvm-project/pull/186453