55import logging
66import pkgutil
77from importlib import import_module
8- from typing import Dict , Any , List
8+ from typing import Dict , Any , List , Callable
99
1010from fastapi import FastAPI
1111
@@ -19,6 +19,7 @@ class BasePlugin:
1919
2020class Component :
2121 instance = None
22+ hooks : Dict [str , dict ]
2223
2324 def get (self ):
2425 return self .instance
@@ -37,6 +38,7 @@ class PluginManager:
3738 optional_components : Dict [str , Component ]
3839
3940 hooks : Dict [str , Dict ]
41+ _temp_hook_funcs : Dict [str , Callable ]
4042
4143 def __init__ (self ):
4244 self .status = defaultdict (dict )
@@ -47,12 +49,22 @@ def __init__(self):
4749 """
4850 Each hook have multiple options, those are.
4951 * required: True|False, if app-start should fail if it is missing.
50- * mode: one of
51- * IMMUTABLE: Hook can only be set once, it can not be overwritten.
5252 """
53- self .hooks = {'version' : {'mode' : 'IMMUTABLE' }}
53+ self .hooks = {'version' : {}}
5454
55- def check_required (self ):
55+ self ._temp_hook_funcs = {}
56+
57+ def post_hook_registrations (self ):
58+ self .finish_hook_registration ()
59+ self .check_invalid_hooks ()
60+ self .check_required_hooks ()
61+
62+ def finish_hook_registration (self ):
63+ for name in self .hooks .keys ():
64+ if name in self ._temp_hook_funcs :
65+ self .hooks [name ]['func' ] = self ._temp_hook_funcs [name ]
66+
67+ def check_required_hooks (self ):
5668 missing = []
5769 for hookname , config in self .hooks .items ():
5870 if config .get ('required' ):
@@ -61,6 +73,11 @@ def check_required(self):
6173 if missing :
6274 raise Exception (f'Missing required hooks: { missing } ' )
6375
76+ def check_invalid_hooks (self ):
77+ for hookname in self ._temp_hook_funcs :
78+ if not hookname in self .hooks :
79+ raise Exception (f'Hook "{ hookname } " is not registered for use.' )
80+
6481 def register_driver (self , name : str , component : Component ):
6582 name = name .lower ()
6683 if name in self .component_drivers :
@@ -99,19 +116,18 @@ async def load_components(self):
99116 connection_status = driverinstance .connect (opts = values .get ('OPTS' , {}))
100117 self .optional_components [name ] = driverinstance
101118
102- def register_hook (self , name , func ):
103- if name not in self .hooks :
104- raise Exception (f'Invalid plugin hook "{ name } ", see docs for valid hooks.' )
119+ def add_hooks (self , hooks ):
120+ for k , v in hooks .items ():
121+ if k in self .hooks :
122+ raise Exception (f'Hook "{ k } " can only be added once' )
123+ logging .debug (f'Adding hook { k } with data { v } ' )
124+ self .hooks [k ] = v
105125
106- if (
107- self .hooks [name ].get ('mode' ) == 'IMMUTABLE'
108- and self .hooks [name ].get ('func' ) is not None
109- ):
110- raise Exception (
111- f'Hook { name } is already occupied and it is marked as IMMUTABLE'
112- )
126+ def register_hook (self , name , func ):
127+ if name in self ._temp_hook_funcs :
128+ raise Exception (f'Hook "{ name } " is already handled' )
113129
114- self .hooks [name ][ 'func' ] = func
130+ self ._temp_hook_funcs [name ] = func
115131
116132 def run_setup_queue (self , app ):
117133 for plugin in self .setup_queue :
@@ -239,6 +255,9 @@ async def startup():
239255
240256 plugin_manager .setup_queue .append ({'obj' : obj , 'name' : plugin .name })
241257
258+ if hasattr (obj , 'hooks' ):
259+ plugin_manager .add_hooks (obj .hooks )
260+
242261 if hasattr (obj , 'startup' ):
243262 plugin_manager .status [plugin .name ]['startup' ] = obj .startup (
244263 ** filter_dict_to_function (
@@ -251,7 +270,7 @@ async def startup():
251270 )
252271
253272 await plugin_manager .load_components ()
254- plugin_manager .check_required ()
273+ plugin_manager .post_hook_registrations ()
255274
256275
257276async def shutdown ():
0 commit comments