```python
from ml_dtypes import bfloat16
import numpy as np
x = np.arange(4, dtype=bfloat16)
np.save('out.npy', x)
print(np.load('out.npy'))
# [b'\x00\x00' b'\x80\x3F' b'\x00\x40' b'\x40\x40']
```
This happens because `np.save` records the dtype in the file header as a string built from its kind code, and we chose `np.dtype(bfloat16).kind = 'V'` because such codes are not extensible in NumPy. When the dtype is reconstructed from that code, the result is:
```python
np.dtype(np.dtype(bfloat16).kind)
# dtype('V')
```
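A minimal sketch of that round trip, assuming `numpy.lib.format.dtype_to_descr` and `descr_to_dtype` (the helpers `np.save` and `np.load` use for the `.npy` header) behave as in current NumPy releases. It shows that the stored descriptor ends in `V2`, that loading yields a plain void dtype, and that the raw bytes survive. The final `.view(bfloat16)` step is an illustrative manual workaround that only works when the reader already knows the original dtype; it is not a fix.

```python
import os
import tempfile

import numpy as np
from numpy.lib import format as npy_format
from ml_dtypes import bfloat16

# Descriptor string that np.save writes into the .npy header:
# byte-order char + kind code ('V') + item size (2).
descr = npy_format.dtype_to_descr(np.dtype(bfloat16))
print(descr)

# np.load rebuilds the dtype from that string, so bfloat16 is lost.
rebuilt = npy_format.descr_to_dtype(descr)
print(rebuilt, rebuilt.kind)

x = np.arange(4, dtype=bfloat16)
path = os.path.join(tempfile.mkdtemp(), "out.npy")
np.save(path, x)
loaded = np.load(path)
print(loaded.dtype)

# The bytes are intact, so a caller who knows the original dtype can
# reinterpret them (manual workaround, assumes the dtype is known).
restored = loaded.view(bfloat16)
print(restored)
```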
I don't think this is something that can be fixed without deeper changes to numpy itself.