Skip to content

Commit 2fd7275

Browse files
committed
Update state.py
Fix pourpossible race condition
1 parent eb9a673 commit 2fd7275

File tree

1 file changed

+47
-29
lines changed

1 file changed

+47
-29
lines changed

pyhilo/util/state.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from __future__ import annotations
44

55
import asyncio
6+
import os
7+
import tempfile
68
from datetime import datetime
79
from os.path import isfile
810
from typing import Any, ForwardRef, TypedDict, TypeVar, get_type_hints
@@ -116,51 +118,69 @@ def _get_defaults(cls: type[T]) -> T:
116118
new_dict[k] = None # type: ignore[literal-required]
117119
return new_dict # type: ignore[return-value]
118120

121+
def _write_state(state_yaml: str, state: dict[str, Any]) -> None:
122+
"Write state atomically to a temp file, this prevents reading a file being written to"
119123

120-
async def get_state(state_yaml: str) -> StateDict:
124+
dir_name = os.path.dirname(os.path.abspath(state_yaml))
125+
content = yaml.dump(state)
126+
with tempfile.NamedTemporaryFile(mode = "w", dir=dir_name, delete=False, suffix=".tmp") as tmp:
127+
tmp.write(content)
128+
tmp_path = tmp.name
129+
os.replace(tmp_path, state_yaml)
130+
131+
132+
async def get_state(state_yaml: str, _already_locked: bool = False) -> StateDict:
121133
"""Read in state yaml.
122134
123135
:param state_yaml: filename where to read the state
124136
:type state_yaml: ``str``
137+
:param _already_locked: Whether the lock is already held by the caller (e.g. set_state).
138+
Prevents deadlock when corruption recovery needs to write defaults.
139+
:type _already_locked: ``bool``
125140
:rtype: ``StateDict``
126141
"""
127-
if not isfile(
128-
state_yaml
129-
): # noqa: PTH113 - isfile is fine and simpler in this case.
142+
if not isfile(state_yaml): # noqa: PTH113 - isfile is fine and simpler in this case.
130143
return _get_defaults(StateDict)
131144

132145
try:
133146
async with aiofiles.open(state_yaml, mode="r") as yaml_file:
134147
LOG.debug("Loading state from yaml")
135148
content = await yaml_file.read()
136-
state_yaml_payload: StateDict | None = yaml.safe_load(content)
137-
138-
# Handle corrupted/empty YAML files
139-
if state_yaml_payload is None or not isinstance(state_yaml_payload, dict):
140-
LOG.warning(
141-
"State file %s is corrupted or empty, reinitializing with defaults",
142-
state_yaml,
143-
)
144-
defaults = _get_defaults(StateDict)
145-
async with aiofiles.open(state_yaml, mode="w") as yaml_file_write:
146-
content = yaml.dump(defaults)
147-
await yaml_file_write.write(content)
148-
return defaults
149+
150+
state_yaml_payload: StateDict | None = yaml.safe_load(content)
151+
152+
# Handle corrupted/empty YAML files
153+
if state_yaml_payload is None or not isinstance(state_yaml_payload, dict):
154+
LOG.warning(
155+
"State file %s is corrupted or empty, reinitializing with defaults",
156+
state_yaml,
157+
)
158+
defaults = _get_defaults(StateDict)
159+
if _already_locked:
160+
_write_state(state_yaml, defaults)
161+
else:
162+
async with lock:
163+
_write_state(state_yaml, defaults)
164+
return defaults
149165

150166
return state_yaml_payload
167+
151168
except yaml.YAMLError as e:
152169
LOG.error(
153170
"Failed to parse state file %s: %s. Reinitializing with defaults.",
154171
state_yaml,
155172
e,
156173
)
157174
defaults = _get_defaults(StateDict)
158-
async with aiofiles.open(state_yaml, mode="w") as yaml_file_write:
159-
content = yaml.dump(defaults)
160-
await yaml_file_write.write(content)
175+
if _already_locked:
176+
_write_state(state_yaml, defaults)
177+
else:
178+
async with lock:
179+
_write_state(state_yaml, defaults)
161180
return defaults
162181

163182

183+
164184
async def set_state(
165185
state_yaml: str,
166186
key: str,
@@ -169,6 +189,7 @@ async def set_state(
169189
),
170190
) -> None:
171191
"""Save state yaml.
192+
172193
:param state_yaml: filename where to read the state
173194
:type state_yaml: ``str``
174195
:param key: Key name
@@ -178,14 +199,11 @@ async def set_state(
178199
:rtype: ``StateDict``
179200
"""
180201
async with lock: # note ic-dev21: on lock le fichier pour être sûr de finir la job
181-
current_state = await get_state(state_yaml) or {}
202+
current_state = await get_state(state_yaml, _already_locked=True) or {}
182203
merged_state: dict[str, Any] = {key: {**current_state.get(key, {}), **state}} # type: ignore[dict-item]
183204
new_state: dict[str, Any] = {**current_state, **merged_state}
184-
async with aiofiles.open(state_yaml, mode="w") as yaml_file:
185-
LOG.debug("Saving state to yaml file")
186-
# TODO: Use asyncio.get_running_loop() and run_in_executor to write
187-
# to the file in a non blocking manner. Currently, the file writes
188-
# are properly async but the yaml dump is done synchronously on the
189-
# main event loop.
190-
content = yaml.dump(new_state)
191-
await yaml_file.write(content)
205+
LOG.debug("Saving state to yaml file")
206+
# TODO: Use asyncio.get_running_loop() and run_in_executor to write
207+
# to the file in a non blocking manner. Currently, yaml.dump is
208+
# synchronous on the main event loop.
209+
_write_state_sync(state_yaml, new_state)

0 commit comments

Comments
 (0)