[Mlir-commits] [mlir] Support float8_e3m4 and float8_e4m3 in np_to_memref (PR #186453)
llvmlistbot at llvm.org
Sat Mar 14 17:53:19 PDT 2026
================
@@ -68,15 +88,17 @@ def to_numpy(array):
if array.dtype == F16:
return array.view("float16")
assert not (
- array.dtype == BF16 and ml_dtypes is None
- ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+ array.dtype in ML_DTYPES_REQUIRED and ml_dtypes is None
+ ), f"{array.dtype.__name__} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
----------------
srcarroll wrote:
Thanks for the suggestion. I'm not familiar with what it does, but I'll check it out. I used `__name__` just to get a nicer-looking message, since simply printing `array.dtype` is a little ugly: it gives something like `<class Blah>` (I'm not at my computer to check, so I don't remember the exact format).
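To illustrate the formatting difference described above, here is a minimal sketch, assuming a ctypes class like the `BF16` wrapper used in `np_to_memref`. The class body and the helper `missing_ml_dtypes_message` are hypothetical stand-ins written for this example. The sketch only shows what `str()` and `__name__` produce for a class object; it does not show what `array.dtype.__name__` returns in `to_numpy`, where `array.dtype` is whatever the NumPy array carries.

```python
import ctypes


# Hypothetical stand-in for a ctypes dtype class such as BF16.
class BF16(ctypes.Structure):
    """A ctypes representation of a 16-bit bfloat value (illustrative)."""
    _fields_ = [("bf16", ctypes.c_int16)]


def missing_ml_dtypes_message(dtype_cls):
    # Hypothetical helper: __name__ yields the bare class name ("BF16")
    # rather than the "<class 'module.BF16'>" form produced by str().
    return (
        f"{dtype_cls.__name__} requires the ml_dtypes package, please run:\n\n"
        "pip install ml_dtypes\n"
    )


print(str(BF16))  # <class '__main__.BF16'> when run as a script
print(missing_ml_dtypes_message(BF16).splitlines()[0])
# BF16 requires the ml_dtypes package, please run:
```

Using `__name__` keeps the message readable for every type in the required set without hardcoding a string per type.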
https://github.com/llvm/llvm-project/pull/186453