Skip to content

Speed-up Sinkhorn#57

Merged
rflamary merged 3 commits intoPythonOT:masterfrom
LeoGautheron:master
Jul 18, 2018
Merged

Speed-up Sinkhorn#57
rflamary merged 3 commits intoPythonOT:masterfrom
LeoGautheron:master

Conversation

@LeoGautheron
Copy link

Speed-up in 3 places:

  • the computation of pairwise distance is faster with sklearn.metrics.pairwise.euclidean_distances
  • faster computation of K = np.exp(-M / reg)
  • faster computation of the error every 10 iterations

Example with this little script:

import time
import numpy as np
import ot
rng = np.random.RandomState(0)
transport = ot.da.SinkhornTransport()
time1 = time.time()
Xs, ys, Xt = rng.randn(10000, 100), rng.randint(0, 2, size=10000), rng.randn(10000, 100)
transport.fit(Xs=Xs, Xt=Xt)
time2 = time.time()
print("OT Computation Time {:6.2f} sec".format(time2-time1))
transport = ot.da.SinkhornLpl1Transport()
transport.fit(Xs=Xs, ys=ys, Xt=Xt)
time3 = time.time()
print("OT LpL1 Computation Time {:6.2f} sec".format(time3-time2))

Before
OT Computation Time 19.93 sec
OT LpL1 Computation Time 133.43 sec

After
OT Computation Time 7.55 sec
OT LpL1 Computation Time 82.25 sec

Speed-up in 3 places:
 - the computation of pairwise distance is faster with sklearn.metrics.pairwise.euclidean_distances
 - faster computation of K = np.exp(-M / reg)
 - faster computation of the error every 10 iterations

Example with this little script:

import time
import numpy as np
import ot
rng = np.random.RandomState(0)
transport = ot.da.SinkhornTransport()
time1 = time.time()
Xs, ys, Xt = rng.randn(10000, 100), rng.randint(0, 2, size=10000), rng.randn(10000, 100)
transport.fit(Xs=Xs, Xt=Xt)
time2 = time.time()
print("OT Computation Time {:6.2f} sec".format(time2-time1))
transport = ot.da.SinkhornLpl1Transport()
transport.fit(Xs=Xs, ys=ys, Xt=Xt)
time3 = time.time()
print("OT LpL1 Computation Time {:6.2f} sec".format(time3-time2))

Before
OT Computation Time  19.93 sec
OT LpL1 Computation Time 133.43 sec

After
OT Computation Time   7.55 sec
OT LpL1 Computation Time  82.25 sec
@rflamary
Copy link
Collaborator

Hello @LeoGautheron ,

Thank you for all your work. Those are very nice speedup and i can confirm that I have similar gains. Still have a few comments.

  • My main concern is to add sklearn as a dependency. Up to now we managed to avoid that (and the tests fail because sklearn is not a hard dependency).

When i run the following code

ot.tic()
M1=ot.dist(x1,x2)
ot.toc()

ot.tic()
M2=euclidean_distances(x1,x2, squared=True)
ot.toc()

ot.tic()
x1p2 = np.sum(np.square(x1), 1)
x2p2 = np.sum(np.square(x2), 1)
M3=x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)
ot.toc()

I get the following:

Elapsed time : 2.0640811920166016 s
Elapsed time : 0.37867116928100586 s
Elapsed time : 0.46784543991088867 s

with the last one pure numpy. I think the gain is sufficient with the last implementation and avoid an additional dependency for POT.

  • Next I looked at all your speedup and they all come at the cost of hard to read code. Since they all bring at least a 20% gain in performance i'm ready to keep them but please add some comments around the code like
# Next N lines equivalent to:
# K= np.exp(-M/reg)

For the next guy who wants to have a look at the function.

In any case its nice work and the computational gain is very important.

@agramfort
Copy link
Collaborator

just copy the code you need then https://github.com/scikit-learn/scikit-learn/blob/a24c8b46/sklearn/metrics/pairwise.py#L163

it's pure python code.

@LeoGautheron
Copy link
Author

I used the code from sklearn, still the same performances now :)

@rflamary rflamary merged commit 5cd6c0a into PythonOT:master Jul 18, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants