[MRG] POT numpy/torch/jax backends #249

Merged
rflamary merged 93 commits into master from backend
Jun 1, 2021
Conversation

@rflamary
Collaborator

@rflamary rflamary commented May 4, 2021

Types of changes

We are working on a new multi-backend API for POT with the objective to be able to handle numpy/torch/jax/cupy arrays seamlessly.

This will be a big one.

Expected features :

  • Numpy backend
  • Torch backend + tests on CPU
  • Jax backend + tests on CPU
  • ot.dist working with numpy/jax/torch arrays
  • ot.emd working with numpy/jax/torch arrays
  • ot.emd2 working with numpy/jax/torch arrays (with working gradients w.r.t. a, b, M, except for JAX, which requires defining a new function)
  • ot.sinkhorn working with numpy/jax/torch arrays (default algorithm only)
  • ot.sinkhorn2 working with numpy/jax/torch arrays (with working gradients wrt a,b,M)
  • Update documentation for emd/emd2 and sinkhorn/sinkhorn2 and the quickstart guide for the new backend
  • ot.utils.proj_simplex working with numpy/jax/torch arrays
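As a rough sketch of how this kind of dispatch can work, here is a toy `get_backend` and `NumpyBackend` mimicking the pattern described above (hypothetical names, NumPy-only; not POT's actual implementation):

```python
import numpy as np

class NumpyBackend:
    """Toy backend: wraps a handful of NumPy calls behind a common interface."""
    def sum(self, a):
        return np.sum(a)
    def exp(self, a):
        return np.exp(a)

def get_backend(*arrays):
    # In the real mechanism this would inspect the array types
    # (np.ndarray, torch.Tensor, jax array) and return the matching
    # backend; only the NumPy case is sketched here.
    if all(isinstance(a, np.ndarray) for a in arrays):
        return NumpyBackend()
    raise TypeError("no backend registered for these array types")

a = np.ones(3) / 3
nx = get_backend(a)
total = nx.sum(a)  # same code would run unchanged for other backends
```

A torch or jax backend would implement the same methods on top of its own array type, so library code written against `nx` never touches a framework API directly.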

@codecov

codecov bot commented May 4, 2021

Codecov Report

Merging #249 (a229668) into master (1f16614) will increase coverage by 0.95%.
The diff coverage is 97.01%.

@@            Coverage Diff             @@
##           master     #249      +/-   ##
==========================================
+ Coverage   92.10%   93.06%   +0.95%     
==========================================
  Files          16       17       +1     
  Lines        3054     3445     +391     
==========================================
+ Hits         2813     3206     +393     
+ Misses        241      239       -2     

@rflamary rflamary changed the title from "[WIP] POT multi backends (no judging yet)" to "[WIP] POT multi backends" on May 5, 2021
@rflamary
Collaborator Author

Just updated the doc one last time, it should build ;).

The PR feels OK, wdyt @ncourty ?


lst_tot = []

# run once with the NumPy backend and once with the backend under test
for nx in [ot.backend.NumpyBackend(), backend]:
Collaborator
NumpyBackend is not already in backend_list? if not so why?

Collaborator Author

Because we run it for NumPy and for the current backend, and check that they provide the same output (see below the loop).
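The comparison pattern being described can be sketched like this (a toy stand-in: plain functions play the role of `NumpyBackend()` and the backend under test):

```python
import numpy as np

# Run the same computation under each "backend" and check the outputs agree.
def run_with_backends(compute, backends):
    results = [compute(nx) for nx in backends]
    for other in results[1:]:
        np.testing.assert_allclose(np.asarray(results[0]),
                                   np.asarray(other), rtol=1e-5)
    return results

# Toy backends: two different implementations of the same sum-of-squares.
backends = [lambda x: np.sum(x * x),
            lambda x: float(sum(v * v for v in x))]
data = np.arange(4.0)
lst_tot = run_with_backends(lambda nx: nx(data), backends)
```

The real test loop does the same thing with actual backend objects: the NumPy result serves as the reference, and every other backend must reproduce it.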

v = rnd.randn(10)
c = rnd.randn(1)

if torch:
Collaborator

The fact that you have to branch for torch here suggests that the backend mechanism is not consistent.

Collaborator Author

For gradients, no, we cannot be consistent, because each backend handles gradients differently. The fact that we have a branch means that we can handle the different ways to compute the gradients (when they can be computed).

If you want to test that the backend works, you have to do the test using functions specific to each backend.
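A toy illustration of why such tests are backend-specific: with torch one would call `loss.backward()` and read `a.grad`, with jax one would wrap the loss in `jax.grad`. Here a NumPy finite-difference check stands in for either (a sketch, not the actual test code):

```python
import numpy as np

def loss(w):
    return float(np.sum(w ** 2))  # analytic gradient is 2 * w

def numeric_grad(f, w, eps=1e-6):
    # Central finite differences: a backend-agnostic (but slow) gradient check.
    g = np.zeros_like(w)
    for i in range(w.size):
        d = np.zeros_like(w)
        d[i] = eps
        g[i] = (f(w + d) - f(w - d)) / (2 * eps)
    return g

w = np.array([1.0, -2.0, 0.5])
g = numeric_grad(loss, w)  # should match 2 * w up to finite-difference error
```

Each autodiff framework exposes a different API for the exact same quantity, which is precisely why the test has to branch per backend.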

@rflamary rflamary changed the title from "[WIP] POT numpy/torch/jax backends" to "[MRG] POT numpy/torch/jax backends" on Jun 1, 2021
@rflamary
Collaborator Author

rflamary commented Jun 1, 2021

Here is the new example. I love how easily the Wasserstein loss can be used and optimized.

https://626-71472695-gh.circle-artifacts.com/0/dev/auto_examples/backends/plot_unmix_optim_torch.html#sphx-glr-auto-examples-backends-plot-unmix-optim-torch-py
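The core of the linked example can be sketched without torch: a projected-gradient loop on the probability simplex, here with a quadratic loss standing in for the Wasserstein loss and a hand-written projection in the spirit of `ot.utils.proj_simplex` (all names and the loss are illustrative stand-ins, not the example's code):

```python
import numpy as np

def proj_simplex(v):
    # Euclidean projection onto the probability simplex.
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    rho = np.nonzero(u * np.arange(1, len(v) + 1) > (css - 1))[0][-1]
    theta = (css[rho] - 1) / (rho + 1.0)
    return np.maximum(v - theta, 0.0)

target = np.array([0.2, 0.5, 0.3])  # ground-truth mixture weights
w = np.ones(3) / 3                  # start from uniform weights
for _ in range(200):
    grad = 2 * (w - target)             # gradient of ||w - target||^2
    w = proj_simplex(w - 0.1 * grad)    # gradient step, then project
```

In the actual example the gradient comes from torch autograd through the Sinkhorn loss rather than from a closed form, but the step-then-project structure is the same.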

Just corrected the legend in the last figure. Will merge when tests pass.

@rflamary rflamary merged commit 184f8f4 into master Jun 1, 2021
@agramfort
Collaborator

agramfort commented Jun 1, 2021 via email
