@@ -26,12 +26,13 @@ class Driver(BasePlugin):
2626 opts : Dict [str , Any ]
2727 pm : 'PluginManager'
2828
29- def __init__ (self , opts = None ):
29+ def __init__ (self , opts = None , load = 'auto' ):
3030 self .opts = opts or {}
31+ self .load = load
3132
3233 def _pre_connection_check (self , connectionstatus , load ):
3334 if connectionstatus == False : # if no host found
34- if load == 'yes' :
35+ if self . load == 'yes' :
3536 raise Exception (
3637 f'Connect pre-check failed for { self .name } , as if the host is not there? Options { self .opts } '
3738 )
@@ -41,14 +42,14 @@ def _pre_connection_check(self, connectionstatus, load):
4142 if hasattr (self , 'validate' ):
4243 return True
4344
44- def initialize (self , load = None ):
45+ def initialize (self ):
4546 connectionstatus = self .connect ()
46- if self ._pre_connection_check (connectionstatus , load ):
47+ if self ._pre_connection_check (connectionstatus , self . load ):
4748 self .validate ()
4849
49- async def initialize_async (self , load = None ):
50+ async def initialize_async (self ):
5051 connectionstatus = await self .connect ()
51- if self ._pre_connection_check (connectionstatus , load ):
52+ if self ._pre_connection_check (connectionstatus , self . load ):
5253 await self .validate ()
5354
5455 def get_instance (self ):
@@ -70,8 +71,9 @@ class Setup(BasePlugin):
7071 ...
7172
7273
73- def get_defined_plugins (mod ):
74- returndata = {'hook-definitions' : [], 'hooks' : [], 'drivers' : [], 'setup' : []}
74+ def get_defined_plugins (mod , plugin_types = None ):
75+ plugin_types = plugin_types or ['hook-definitions' , 'hooks' , 'drivers' , 'setup' ]
76+ returndata = defaultdict (list )
7577
7678 for name , obj in inspect .getmembers (mod , inspect .isclass ):
7779 if mod .__name__ != obj .__module__ :
@@ -81,13 +83,13 @@ def get_defined_plugins(mod):
8183 if obj is Hook or obj is Driver or obj is Setup :
8284 continue
8385
84- if issubclass (obj , Hook ):
86+ if issubclass (obj , Hook ) and 'hooks' in plugin_types :
8587 returndata ['hooks' ].append (obj )
86- elif issubclass (obj , HookDefinition ):
88+ elif issubclass (obj , HookDefinition ) and 'hook-definitions' in plugin_types :
8789 returndata ['hook-definitions' ].append (obj )
88- elif issubclass (obj , Driver ):
90+ elif issubclass (obj , Driver ) and 'drivers' in plugin_types :
8991 returndata ['drivers' ].append (obj )
90- elif issubclass (obj , Setup ):
92+ elif issubclass (obj , Setup ) and 'setup' in plugin_types :
9193 returndata ['setup' ].append (obj )
9294 return returndata
9395
@@ -100,13 +102,17 @@ class PluginManager:
100102 hooks : Dict [str , List [Hook ]]
101103 hook_definitions : Dict [str , Hook ]
102104
105+ store : Dict [str , Any ]
106+
103107 def __init__ (self ):
104108 self .status = defaultdict (dict )
105109 self .drivers = {}
106110 self .optional_components = {}
107111 self .hooks = defaultdict (list )
108112 self .hook_definitions = {}
109113
114+ self .store = {'task_candidates' : []}
115+
110116 def post_hook (self ):
111117 final_hooks : Dict [str , Hook ] = {}
112118 for hook_name , hooks in self .hooks .items ():
@@ -143,13 +149,7 @@ def register_driver(self, driver: Driver):
143149 logging .debug (f'Registered driver { name } ' )
144150 self .drivers [name ] = driver
145151
146- async def load_components (self ):
147- """
148- Preload the components that we are going to use.
149- The components will be available using singletons that represent connections.
150- We will therefor reuse each components connection and connect once per
151- definition in config.OPTIONAL_COMPONENTS
152- """
152+ def _preload_drivers (self ):
153153 for name , values in config .OPTIONAL_COMPONENTS .items ():
154154 name = name .lower ()
155155 load = values .get ('LOAD' , 'auto' )
@@ -164,18 +164,30 @@ async def load_components(self):
164164 f'Invalid driver specified ({ drivername } ), no way to handle it'
165165 )
166166
167- driverinstance = driver (opts = values .get ('OPTS' , {}))
167+ driverinstance = driver (opts = values .get ('OPTS' , {}), load = load )
168168 driverinstance .pm = self
169-
170- if asyncio .iscoroutinefunction (driverinstance .connect ):
171- await driverinstance .initialize_async (load = load )
172- else :
173- driverinstance .initialize (load = load )
174169 self .optional_components [name ] = driverinstance
175170
176171 logging .info (
177172 f'Connecting to { name } with driver { drivername } , using { driverinstance .opts } '
178173 )
174+ yield driverinstance
175+
176+
177+ async def load_components (self ):
178+ for driverinstance in self ._preload_drivers ():
179+ if asyncio .iscoroutinefunction (driverinstance .connect ):
180+ await driverinstance .initialize_async ()
181+ else :
182+ driverinstance .initialize ()
183+
184+ def load_sync_components_global (self ):
185+ for driverinstance in self ._preload_drivers ():
186+ if asyncio .iscoroutinefunction (driverinstance .connect ):
187+ logging .debug (f'Driver { driverinstance .name } is async, wont load' )
188+ else :
189+ driverinstance .initialize ()
190+
179191
180192 def register_hook_definition (self , obj ):
181193 try :
@@ -238,14 +250,7 @@ async def call_async(self, name, *args, **kwargs):
238250
239251 return await self .hooks [name ].run (* args , ** kwargs )
240252
241-
242- plugin_manager : PluginManager
243-
244-
245- async def startup (app ):
246- global plugin_manager
247- plugin_manager = PluginManager () # Singleton used around the app
248-
253+ def _get_plugindata ():
249254 """
250255 Plugins are imported from multiple paths with these rules:
251256 * First with a unique name wins
@@ -284,12 +289,17 @@ async def startup(app):
284289 sys .path = unique (sys_paths )
285290
286291 plugins_to_load = defaultdict (list )
292+ task_candidates = []
287293
288294 for plugin in pkgutil .iter_modules (PLUGIN_PATHS ):
289295 allow_match = os .path .join (plugin .module_finder .path , plugin .name )
296+ tasks_candidate = False
290297
291298 if plugin .ispkg :
292299 metafile = os .path .join (allow_match , 'meta.json' )
300+
301+ if os .path .exists (os .path .join (allow_match , 'tasks.py' )):
302+ tasks_candidate = True
293303 else :
294304 metafile = f'{ allow_match } -meta.json'
295305
@@ -345,24 +355,60 @@ async def startup(app):
345355 logging .info (f'Loading plugin: { plugin .name } ' )
346356 mod = import_module (plugin .name )
347357
358+ if tasks_candidate :
359+ task_candidates .append (plugin .name )
360+
348361 defined_plugins = get_defined_plugins (mod )
349362 for pt in ['hook-definitions' , 'hooks' , 'drivers' , 'setup' ]:
350363 plugins_to_load [pt ] += defined_plugins [pt ]
351364
352- for hook_definition in plugins_to_load ['hook-definitions' ]:
365+ return {'plugins_to_load' : plugins_to_load , 'task_candidates' : task_candidates }
366+
367+
368+ plugin_manager : PluginManager
369+
370+
371+ async def startup (app ):
372+ global plugin_manager
373+ plugin_manager = PluginManager ()
374+
375+ plugin_manager .store .update (** _get_plugindata ())
376+
377+ for hook_definition in plugin_manager .store ['plugins_to_load' ]['hook-definitions' ]:
353378 plugin_manager .register_hook_definition (hook_definition )
354379
355- for hook in plugins_to_load ['hooks' ]:
380+ for hook in plugin_manager . store [ ' plugins_to_load' ] ['hooks' ]:
356381 plugin_manager .register_hook (hook )
357382 plugin_manager .post_hook ()
358383
359- for driver in plugins_to_load ['drivers' ]:
384+ for driver in plugin_manager . store [ ' plugins_to_load' ] ['drivers' ]:
360385 plugin_manager .register_driver (driver )
361386 await plugin_manager .load_components ()
362387
363- for setup in plugins_to_load ['setup' ]:
388+ for setup in plugin_manager . store [ ' plugins_to_load' ] ['setup' ]:
364389 plugin_manager .run_setup (setup , {'app' : app , 'pm' : plugin_manager })
365390
391+ def startup_worker ():
392+ """
393+ This function is dedicated to celery worker startup. We can't use the regular one
394+ because it is async, and it is tailored to fastapi
395+ """
396+ global plugin_manager
397+ plugin_manager = PluginManager ()
398+
399+ plugin_manager .store .update (** _get_plugindata ())
400+
401+ for hook_definition in plugin_manager .store ['plugins_to_load' ]['hook-definitions' ]:
402+ plugin_manager .register_hook_definition (hook_definition )
403+
404+ for hook in plugin_manager .store ['plugins_to_load' ]['hooks' ]:
405+ plugin_manager .register_hook (hook )
406+ plugin_manager .post_hook ()
407+
408+ for driver in plugin_manager .store ['plugins_to_load' ]['drivers' ]:
409+ plugin_manager .register_driver (driver )
410+ plugin_manager .load_sync_components_global ()
411+
366412
367413async def shutdown ():
368414 pass
0 commit comments