1+ import pickle
2+ import time
3+ import inspect
4+ import base64
5+ import hashlib
6+
7+
8+ debug = False
9+
10+
11+ def log (s ):
12+ if debug :
13+ print (s )
14+
15+
16+ caches = dict ()
17+ updated_caches = []
18+
19+
20+ def get_cache (fname ):
21+ if fname in caches :
22+ return caches [fname ]
23+ try :
24+ with open (fname , "rb" ) as f :
25+ c = pickle .load (f )
26+ except :
27+ c = dict ()
28+ caches [fname ] = c
29+ return c
30+
31+
32+ def write_to_cache (fname , obj ):
33+ updated_caches .append (fname )
34+ caches [fname ] = obj
35+
36+
37+ def cleanup ():
38+ for fname in updated_caches :
39+ with open (fname , "wb" ) as f :
40+ pickle .dump (caches [fname ], f )
41+
42+
43+ def get_fn_hash (f ):
44+ return base64 .b64encode (hashlib .sha1 (inspect .getsource (f ).encode ("utf-8" )).digest ())
45+
46+
47+ NONE = 0
48+ ARGS = 1
49+ KWARGS = 2
50+
51+
52+ def cache (fname = ".cache.pkl" , timeout = - 1 , key = ARGS | KWARGS ):
53+
54+ def impl (fn ):
55+ load_t = time .time ()
56+ c = get_cache (fname )
57+ log ("loaded cache in {:.2f}s" .format (time .time () - load_t ))
58+
59+ def d (* args , ** kwargs ):
60+ log ("checking cache on {}" .format (fn .__name__ ))
61+ if key == ARGS | KWARGS :
62+ k = pickle .dumps ((fn .__name__ , args , kwargs ))
63+ if key == ARGS :
64+ k = pickle .dumps ((fn .__name__ , args ))
65+ if key == KWARGS :
66+ k = pickle .dumps ((fn .__name__ , kwargs ))
67+ if key == NONE :
68+ k = pickle .dumps ((fn .__name__ ))
69+ if k in c :
70+ h , t , to , res = c [k ]
71+ if get_fn_hash (fn ) == h and (to < 0 or (time .time () - t ) < to ):
72+ log ("cache hit." )
73+ return res
74+ log ("cache miss." )
75+ res = fn (* args , ** kwargs )
76+ c [k ] = (get_fn_hash (fn ), time .time (), timeout , res )
77+ save_t = time .time ()
78+ write_to_cache (fname , c )
79+ log ("saved cache in {:.2f}s" .format (time .time () - save_t ))
80+ return res
81+
82+ return d
83+
84+ return impl
85+
86+
87+ @cache (timeout = 0.2 )
88+ def expensive (k ):
89+ time .sleep (0.2 )
90+ return k
91+
92+
93+ @cache (key = KWARGS )
94+ def expensive2 (k , kwarg1 = None ):
95+ time .sleep (0.2 )
96+ return k
97+
98+
99+ def test ():
100+ # Test timeout
101+ t = time .time ()
102+ v = expensive (1 )
103+ assert v == 1
104+ assert time .time () - t > 0.1
105+ t = time .time ()
106+ expensive (1 )
107+ assert time .time () - t < 0.1
108+ time .sleep (0.3 )
109+ t = time .time ()
110+ expensive (1 )
111+ assert time .time () - t > 0.1
112+ t = time .time ()
113+ v = expensive (2 )
114+ assert v == 2
115+ assert time .time () - t > 0.1
116+ # Test key=_ annotation
117+ t = time .time ()
118+ v = expensive2 (2 , kwarg1 = "test" )
119+ assert v == 2
120+ assert time .time () - t > 0.1
121+ t = time .time ()
122+ v = expensive2 (1 , kwarg1 = "test" )
123+ assert v == 2
124+ assert time .time () - t < 0.1
125+ t = time .time ()
126+ v = expensive2 (1 , kwarg1 = "test2" )
127+ assert v == 1
128+ assert time .time () - t > 0.1
129+ cleanup ()
130+ print ("pass" )
131+
132+
133+ if __name__ == "__main__" :
134+ test ()
0 commit comments