-
Notifications
You must be signed in to change notification settings - Fork 241
Expand file tree
/
Copy pathjax_backend.py
More file actions
336 lines (277 loc) · 12.7 KB
/
jax_backend.py
File metadata and controls
336 lines (277 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
import numpy as np
from docarray.computation.abstract_comp_backend import AbstractComputationalBackend
from docarray.computation.abstract_numpy_based_backend import AbstractNumpyBasedBackend
from docarray.typing import JaxArray
from docarray.utils._internal.misc import import_library
if TYPE_CHECKING:
import jax
import jax.numpy as jnp
else:
jax = import_library('jax', raise_error=True)
jnp = jax.numpy
def _expand_if_single_axis(*matrices: jnp.ndarray) -> List[jnp.ndarray]:
    """Give every one-axis input a new leading axis.

    Inputs whose shape has exactly one axis are expanded at dim 0, turning a
    vector of shape ``(n,)`` into a single-row matrix of shape ``(1, n)``.
    Every other input is returned as the very same object.

    :param matrices: arrays to inspect
    :return: list with one entry per input, in the original order,
        where single-axis arrays have been expanded at dim 0.
    """
    return [
        jnp.expand_dims(mat, axis=0) if len(mat.shape) == 1 else mat
        for mat in matrices
    ]
def _expand_if_scalar(arr: jnp.ndarray) -> jnp.ndarray:
    """Ensure the result has at least one axis.

    :param arr: array that may be zero-dimensional
    :return: ``arr`` itself when it already has one or more axes, otherwise
        ``arr`` expanded at dim 0 (shape ``(1,)``).
    """
    if len(arr.shape) > 0:
        return arr
    # zero-dimensional input: avoid handing a scalar back to the caller
    return jnp.expand_dims(arr, axis=0)
def norm_left(t: jnp.ndarray) -> JaxArray:
    """Wrap a raw JAX array in a `JaxArray`.

    :param t: the array to wrap
    :return: a `JaxArray` constructed with ``tensor=t``
    """
    wrapped = JaxArray(tensor=t)
    return wrapped
def norm_right(t: JaxArray) -> jnp.ndarray:
    """Unwrap a `JaxArray`, returning the object stored in its ``tensor`` attribute.

    :param t: the wrapper to unwrap
    :return: ``t.tensor``
    """
    raw = t.tensor
    return raw
class JaxCompBackend(AbstractNumpyBasedBackend):
    """
    Computational backend for Jax.

    Tensors cross the public API wrapped as `JaxArray`. Internally,
    `_get_tensor` unwraps them to raw JAX arrays and `_cast_output` wraps
    results back into `JaxArray`.
    """

    _module = jnp
    _cast_output: Callable = norm_left
    _get_tensor: Callable = norm_right

    @classmethod
    def to_device(cls, tensor: 'JaxArray', device: str) -> 'JaxArray':
        """Move the tensor to the specified device.

        :param tensor: tensor to move
        :param device: platform name as accepted by `jax.devices`
        :return: `tensor` itself if it already lives on that platform,
            otherwise a new JaxArray placed on the first device of that platform.
        """
        if cls.device(tensor) == device:
            return tensor
        # `jax.devices` returns a *list* of devices; a single array has to be
        # placed on one concrete device, so pick the first one.
        target_device = jax.devices(device)[0]
        return cls._cast_output(
            jax.device_put(cls._get_tensor(tensor), target_device)
        )

    @classmethod
    def device(cls, tensor: 'JaxArray') -> Optional[str]:
        """Return the platform name of the device on which the tensor is allocated."""
        dev = cls._get_tensor(tensor).device
        # Older JAX releases expose `device` as a method, newer ones as a
        # property; support both.
        if callable(dev):
            dev = dev()
        return dev.platform

    @classmethod
    def to_numpy(cls, array: 'JaxArray') -> 'np.ndarray':
        """Convert the wrapped JAX array to a NumPy array."""
        return cls._get_tensor(array).__array__()

    @classmethod
    def none_value(cls) -> Any:
        """Provide a compatible value that represents None in JAX."""
        return jnp.nan

    @classmethod
    def detach(cls, tensor: 'JaxArray') -> 'JaxArray':
        """
        Returns the tensor detached from its current graph.

        :param tensor: tensor to be detached
        :return: a detached tensor with the same data.
        """
        return cls._cast_output(jax.lax.stop_gradient(cls._get_tensor(tensor)))

    @classmethod
    def dtype(cls, tensor: 'JaxArray') -> jnp.dtype:
        """Get the data type of the tensor.

        NOTE(review): the value returned is the dtype's ``name`` attribute,
        not the dtype object the annotation suggests - confirm which one
        callers expect before changing either.
        """
        d_type = cls._get_tensor(tensor).dtype
        return d_type.name

    @classmethod
    def minmax_normalize(
        cls,
        tensor: 'JaxArray',
        t_range: Tuple = (0, 1),
        x_range: Optional[Tuple] = None,
        eps: float = 1e-7,
    ) -> 'JaxArray':
        """
        Normalize values in `tensor` into `t_range`.

        `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
        normalization is row-based.

        !!! note

            - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
            - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
              of the data to 0.

        :param tensor: the data to be normalized
        :param t_range: a tuple represents the target range.
        :param x_range: a tuple represents tensors range. When not given, the
            min and max are taken from the data along the last axis.
        :param eps: a small jitter to avoid dividing by zero
        :return: normalized data in `t_range`, cast back to the dtype of `tensor`
        """
        a, b = t_range
        raw = cls._get_tensor(tensor)
        # compute in float32, cast back to the input dtype at the end
        t = jnp.asarray(raw, jnp.float32)

        if x_range:
            min_d, max_d = x_range[0], x_range[1]
        else:
            min_d = jnp.min(t, axis=-1, keepdims=True)
            max_d = jnp.max(t, axis=-1, keepdims=True)

        r = (b - a) * (t - min_d) / (max_d - min_d + eps) + a
        # clip bounds must be ordered (low, high) even for a reversed t_range
        low, high = (a, b) if a < b else (b, a)
        normalized = jnp.clip(r, low, high)
        return cls._cast_output(jnp.asarray(normalized, raw.dtype))

    @classmethod
    def equal(cls, tensor1: 'JaxArray', tensor2: 'JaxArray') -> bool:
        """
        Check if two tensors are equal.

        :param tensor1: the first tensor
        :param tensor2: the second tensor
        :return: True if both tensors have the same shape and all elements
            are equal, False otherwise.
            If one or more of the inputs does not carry a JAX array in its
            `tensor` attribute, return False.
        """
        t1, t2 = getattr(tensor1, 'tensor', None), getattr(tensor2, 'tensor', None)
        if isinstance(t1, jnp.ndarray) and isinstance(t2, jnp.ndarray):
            # shape check first: short-circuits before the element-wise
            # comparison, which would otherwise broadcast or fail
            return t1.shape == t2.shape and bool(jnp.all(jnp.equal(t1, t2)))
        return False

    class Retrieval(AbstractComputationalBackend.Retrieval[JaxArray]):
        """
        Retrieval and ranking functionalities for the Jax backend.
        """

        @staticmethod
        def top_k(
            values: 'JaxArray',
            k: int,
            descending: bool = False,
            device: Optional[str] = None,
        ) -> Tuple['JaxArray', 'JaxArray']:
            """
            Returns the k smallest values in `values` along with their indices.
            Can also be used to retrieve the k largest values,
            by setting the `descending` flag.

            :param values: Jax tensor of values to rank.
                Should be of shape (n_queries, n_values_per_query).
                Inputs of shape (n_values_per_query,) will be expanded
                to (1, n_values_per_query).
            :param k: number of values to retrieve
            :param descending: retrieve largest values instead of smallest values
            :param device: if given, `values` is moved to this device
                before ranking
            :return: Tuple containing the retrieved values, and their indices.
                Both are of shape (n_queries, k)
            """
            comp_be = JaxCompBackend
            if device is not None:
                values = comp_be.to_device(values, device)

            jax_values: jnp.ndarray = comp_be._get_tensor(values)

            if len(jax_values.shape) == 1:
                jax_values = jnp.expand_dims(jax_values, axis=0)

            # rank the negated values so that "smallest first" yields the largest
            if descending:
                jax_values = -jax_values

            if k >= jax_values.shape[1]:
                # asking for everything: a full sort is all that is needed
                idx = jax_values.argsort(axis=1)[:, :k]
                jax_values = jnp.take_along_axis(jax_values, idx, axis=1)
            else:
                # select the k smallest (unordered) first, then sort only those
                idx_ps = jax_values.argpartition(kth=k, axis=1)[:, :k]
                jax_values = jnp.take_along_axis(jax_values, idx_ps, axis=1)
                idx_fs = jax_values.argsort(axis=1)
                idx = jnp.take_along_axis(idx_ps, idx_fs, axis=1)
                jax_values = jnp.take_along_axis(jax_values, idx_fs, axis=1)

            # undo the negation applied above
            if descending:
                jax_values = -jax_values

            return comp_be._cast_output(jax_values), comp_be._cast_output(idx)

    class Metrics(AbstractComputationalBackend.Metrics[JaxArray]):
        """
        Metrics (distances and similarities) for the Jax backend.
        """

        @staticmethod
        def cosine_sim(
            x_mat: 'JaxArray',
            y_mat: 'JaxArray',
            eps: float = 1e-7,
            device: Optional[str] = None,
        ) -> 'JaxArray':
            """Pairwise cosine similarities between all vectors in x_mat and y_mat.

            :param x_mat: tensor of shape (n_vectors, n_dim), where n_vectors is the
                number of vectors and n_dim is the number of dimensions of each example.
            :param y_mat: tensor of shape (n_vectors, n_dim), where n_vectors is the
                number of vectors and n_dim is the number of dimensions of each example.
            :param eps: a small jitter to avoid dividing by zero; it is added to
                both the numerator and the denominator
            :param device: currently ignored by this implementation
            :return: JaxArray containing all pairwise cosine similarities, clipped
                to [-1, 1]. The index [i_x, i_y] contains the cosine similarity
                between x_mat[i_x] and y_mat[i_y]. Axes of length one are
                squeezed away; a scalar result is returned with shape (1,).
            """
            comp_be = JaxCompBackend
            x_mat_jax: jnp.ndarray = comp_be._get_tensor(x_mat)
            y_mat_jax: jnp.ndarray = comp_be._get_tensor(y_mat)

            x_mat_jax, y_mat_jax = _expand_if_single_axis(x_mat_jax, y_mat_jax)

            sims = jnp.clip(
                (jnp.dot(x_mat_jax, y_mat_jax.T) + eps)
                / (
                    jnp.outer(
                        jnp.linalg.norm(x_mat_jax, axis=1),
                        jnp.linalg.norm(y_mat_jax, axis=1),
                    )
                    + eps
                ),
                -1,
                1,
            ).squeeze()
            sims = _expand_if_scalar(sims)

            return comp_be._cast_output(sims)

        @classmethod
        def euclidean_dist(
            cls, x_mat: JaxArray, y_mat: JaxArray, device: Optional[str] = None
        ) -> JaxArray:
            """Pairwise Euclidian distances between all vectors in x_mat and y_mat.

            :param x_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is
                the number of vectors and n_dim is the number of dimensions of each
                example.
            :param y_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is
                the number of vectors and n_dim is the number of dimensions of each
                example.
            :param device: Not supported for this backend (ignored)
            :return: JaxArray containing all pairwise euclidian distances.
                The index [i_x, i_y] contains the euclidian distance between
                x_mat[i_x] and y_mat[i_y]. Axes of length one are squeezed
                away; a scalar result is returned with shape (1,).
            """
            comp_be = JaxCompBackend

            # sqeuclidean_dist unwraps and expands single-axis inputs itself
            sq_dists = comp_be._get_tensor(cls.sqeuclidean_dist(x_mat, y_mat))
            # Rounding can leave squared distances slightly below zero by more
            # than sqeuclidean_dist's tolerance; clamp so sqrt never yields NaN.
            dists = _expand_if_scalar(jnp.sqrt(jnp.maximum(sq_dists, 0)).squeeze())

            return comp_be._cast_output(dists)

        @staticmethod
        def sqeuclidean_dist(
            x_mat: JaxArray,
            y_mat: JaxArray,
            device: Optional[str] = None,
        ) -> JaxArray:
            """Pairwise Squared Euclidian distances between all vectors in
            x_mat and y_mat.

            :param x_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is
                the number of vectors and n_dim is the number of dimensions of each
                example.
            :param y_mat: jnp.ndarray of shape (n_vectors, n_dim), where n_vectors is
                the number of vectors and n_dim is the number of dimensions of each
                example.
            :param device: Not supported for this backend (ignored)
            :return: JaxArray containing all pairwise Squared Euclidian distances.
                The index [i_x, i_y] contains the Squared Euclidian distance
                between x_mat[i_x] and y_mat[i_y]. Axes of length one are
                squeezed away; a scalar result is returned with shape (1,).
            """
            comp_be = JaxCompBackend
            x_mat_jax: jnp.ndarray = comp_be._get_tensor(x_mat)
            y_mat_jax: jnp.ndarray = comp_be._get_tensor(y_mat)
            eps: float = 1e-7  # avoid problems with numerical inaccuracies

            x_mat_jax, y_mat_jax = _expand_if_single_axis(x_mat_jax, y_mat_jax)

            # ||x - y||^2 = ||y||^2 + ||x||^2 - 2 * <x, y>
            dists = (
                jnp.sum(y_mat_jax**2, axis=1)
                + jnp.sum(x_mat_jax**2, axis=1)[:, jnp.newaxis]
                - 2 * jnp.dot(x_mat_jax, y_mat_jax.T)
            ).squeeze()

            # remove numerical artifacts: tiny negatives within (-eps, 0) become 0.
            # jnp (not np) keeps the computation on JAX arrays.
            dists = jnp.where(jnp.logical_and(dists < 0, dists > -eps), 0, dists)
            dists = _expand_if_scalar(dists)
            return comp_be._cast_output(dists)