33from __future__ import annotations
44
55import asyncio
6+ import os
7+ import tempfile
68from datetime import datetime
79from os .path import isfile
810from typing import Any , ForwardRef , TypedDict , TypeVar , get_type_hints
@@ -116,51 +118,69 @@ def _get_defaults(cls: type[T]) -> T:
116118 new_dict [k ] = None # type: ignore[literal-required]
117119 return new_dict # type: ignore[return-value]
118120
121+ def _write_state (state_yaml : str , state : dict [str , Any ]) -> None :
122+ "Write state atomically to a temp file, this prevents reading a file being written to"
119123
120- async def get_state (state_yaml : str ) -> StateDict :
124+ dir_name = os .path .dirname (os .path .abspath (state_yaml ))
125+ content = yaml .dump (state )
126+ with tempfile .NamedTemporaryFile (mode = "w" , dir = dir_name , delete = False , suffix = ".tmp" ) as tmp :
127+ tmp .write (content )
128+ tmp_path = tmp .name
129+ os .replace (tmp_path , state_yaml )
130+
131+
132+ async def get_state (state_yaml : str , _already_locked : bool = False ) -> StateDict :
121133 """Read in state yaml.
122134
123135 :param state_yaml: filename where to read the state
124136 :type state_yaml: ``str``
137+ :param _already_locked: Whether the lock is already held by the caller (e.g. set_state).
138+ Prevents deadlock when corruption recovery needs to write defaults.
139+ :type _already_locked: ``bool``
125140 :rtype: ``StateDict``
126141 """
127- if not isfile (
128- state_yaml
129- ): # noqa: PTH113 - isfile is fine and simpler in this case.
142+ if not isfile (state_yaml ): # noqa: PTH113 - isfile is fine and simpler in this case.
130143 return _get_defaults (StateDict )
131144
132145 try :
133146 async with aiofiles .open (state_yaml , mode = "r" ) as yaml_file :
134147 LOG .debug ("Loading state from yaml" )
135148 content = await yaml_file .read ()
136- state_yaml_payload : StateDict | None = yaml .safe_load (content )
137-
138- # Handle corrupted/empty YAML files
139- if state_yaml_payload is None or not isinstance (state_yaml_payload , dict ):
140- LOG .warning (
141- "State file %s is corrupted or empty, reinitializing with defaults" ,
142- state_yaml ,
143- )
144- defaults = _get_defaults (StateDict )
145- async with aiofiles .open (state_yaml , mode = "w" ) as yaml_file_write :
146- content = yaml .dump (defaults )
147- await yaml_file_write .write (content )
148- return defaults
149+
150+ state_yaml_payload : StateDict | None = yaml .safe_load (content )
151+
152+ # Handle corrupted/empty YAML files
153+ if state_yaml_payload is None or not isinstance (state_yaml_payload , dict ):
154+ LOG .warning (
155+ "State file %s is corrupted or empty, reinitializing with defaults" ,
156+ state_yaml ,
157+ )
158+ defaults = _get_defaults (StateDict )
159+ if _already_locked :
160+ _write_state (state_yaml , defaults )
161+ else :
162+ async with lock :
163+ _write_state (state_yaml , defaults )
164+ return defaults
149165
150166 return state_yaml_payload
167+
151168 except yaml .YAMLError as e :
152169 LOG .error (
153170 "Failed to parse state file %s: %s. Reinitializing with defaults." ,
154171 state_yaml ,
155172 e ,
156173 )
157174 defaults = _get_defaults (StateDict )
158- async with aiofiles .open (state_yaml , mode = "w" ) as yaml_file_write :
159- content = yaml .dump (defaults )
160- await yaml_file_write .write (content )
175+ if _already_locked :
176+ _write_state (state_yaml , defaults )
177+ else :
178+ async with lock :
179+ _write_state (state_yaml , defaults )
161180 return defaults
162181
163182
183+
164184async def set_state (
165185 state_yaml : str ,
166186 key : str ,
@@ -169,6 +189,7 @@ async def set_state(
169189 ),
170190) -> None :
171191 """Save state yaml.
192+
172193 :param state_yaml: filename where to read the state
173194 :type state_yaml: ``str``
174195 :param key: Key name
@@ -178,14 +199,11 @@ async def set_state(
178199 :rtype: ``StateDict``
179200 """
180201 async with lock : # note ic-dev21: on lock le fichier pour être sûr de finir la job
181- current_state = await get_state (state_yaml ) or {}
202+ current_state = await get_state (state_yaml , _already_locked = True ) or {}
182203 merged_state : dict [str , Any ] = {key : {** current_state .get (key , {}), ** state }} # type: ignore[dict-item]
183204 new_state : dict [str , Any ] = {** current_state , ** merged_state }
184- async with aiofiles .open (state_yaml , mode = "w" ) as yaml_file :
185- LOG .debug ("Saving state to yaml file" )
186- # TODO: Use asyncio.get_running_loop() and run_in_executor to write
187- # to the file in a non blocking manner. Currently, the file writes
188- # are properly async but the yaml dump is done synchronously on the
189- # main event loop.
190- content = yaml .dump (new_state )
191- await yaml_file .write (content )
205+ LOG .debug ("Saving state to yaml file" )
206+ # TODO: Use asyncio.get_running_loop() and run_in_executor to write
207+ # to the file in a non blocking manner. Currently, yaml.dump is
208+ # synchronous on the main event loop.
209+ _write_state_sync (state_yaml , new_state )
0 commit comments