From db7c4a0a06c981910a34cc04c58ce658f234d5ae Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 18 Jun 2024 00:48:56 -0400 Subject: [PATCH] always cast to float32, try to convert other array types too --- fastplotlib/graphics/_features/_base.py | 32 ++++++------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/fastplotlib/graphics/_features/_base.py b/fastplotlib/graphics/_features/_base.py index 1b24d3b78..a57f8a453 100644 --- a/fastplotlib/graphics/_features/_base.py +++ b/fastplotlib/graphics/_features/_base.py @@ -12,36 +12,18 @@ WGPU_MAX_TEXTURE_SIZE = 8192 -supported_dtypes = [ - np.uint8, - np.uint16, - np.uint32, - np.int8, - np.int16, - np.int32, - np.float16, - np.float32, -] - - def to_gpu_supported_dtype(array): """ - If ``array`` is a numpy array, converts it to a supported type. GPUs don't support 64 bit dtypes. + convert input array to float32 numpy array """ if isinstance(array, np.ndarray): - if array.dtype not in supported_dtypes: - if np.issubdtype(array.dtype, np.integer): - warn(f"converting {array.dtype} array to int32") - return array.astype(np.int32) - elif np.issubdtype(array.dtype, np.floating): - warn(f"converting {array.dtype} array to float32") - return array.astype(np.float32, copy=False) - else: - raise TypeError( - "Unsupported type, supported array types must be int or float dtypes" - ) + if not array.dtype == np.float32: + warn(f"casting {array.dtype} array to float32") + return array.astype(np.float32) + return array - return array + # try to make a numpy array from it, should not copy, tested with jax arrays + return np.asarray(array).astype(np.float32) class FeatureEvent(pygfx.Event):