Skip to content

Commit 746438a

Browse files
Fix CI: run ruff format
1 parent: 52ff4f1 · commit: 746438a

12 files changed

Lines changed: 216 additions & 140 deletions

File tree

backend/app/api/routes/generate.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def _cleanup_old_jobs() -> None:
4444

4545
now = time.monotonic()
4646
expired = [
47-
jid for jid, job in _jobs.items()
48-
if now - job.get("created_at", now) > _JOB_TTL_SECONDS
47+
jid for jid, job in _jobs.items() if now - job.get("created_at", now) > _JOB_TTL_SECONDS
4948
]
5049
for jid in expired:
5150
del _jobs[jid]
@@ -106,8 +105,7 @@ async def generate_handwriting(
106105
stream_url = f"{base_url}/api/stream/{job_id}"
107106

108107
logger.info(
109-
f"Job {job_id} created: {len(request_body.text)} chars, "
110-
f"style={request_body.style_id}"
108+
f"Job {job_id} created: {len(request_body.text)} chars, style={request_body.style_id}"
111109
)
112110

113111
return GenerateResponse(
@@ -150,7 +148,9 @@ async def websocket_stream(websocket: WebSocket, job_id: str) -> None:
150148
return
151149

152150
if not engine.is_ready:
153-
await websocket.send_json({"type": "error", "message": "Engine not ready — model still loading"})
151+
await websocket.send_json(
152+
{"type": "error", "message": "Engine not ready — model still loading"}
153+
)
154154
await websocket.close(code=4003)
155155
return
156156

@@ -285,10 +285,12 @@ async def event_generator():
285285
job["status"] = "failed"
286286
job["error"] = str(e)
287287

288-
error_event = json.dumps({
289-
"type": "error",
290-
"message": str(e),
291-
})
288+
error_event = json.dumps(
289+
{
290+
"type": "error",
291+
"message": str(e),
292+
}
293+
)
292294
yield f"data: {error_event}\n\n"
293295

294296
return StreamingResponse(

backend/app/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
# Lifespan — Load model ONCE on startup, release on shutdown
3434
# ============================================================
3535

36+
3637
@asynccontextmanager
3738
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
3839
"""

backend/app/ml/dataset.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,13 @@ def _load_from_stroke_files(self, split: str) -> list[dict]:
132132
else:
133133
text_str = str(text_val)
134134

135-
samples.append({
136-
"strokes": data["strokes"].tolist(),
137-
"text": text_str,
138-
"writer_id": writer_id,
139-
})
135+
samples.append(
136+
{
137+
"strokes": data["strokes"].tolist(),
138+
"text": text_str,
139+
"writer_id": writer_id,
140+
}
141+
)
140142
except Exception:
141143
# Skip files that cannot be loaded without pickle or have other errors
142144
continue
@@ -344,7 +346,7 @@ def parse_iam_xml(xml_path: Path) -> list[tuple[float, float, int, int, int]]:
344346
for i, (x, y, pen_down) in enumerate(abs_points):
345347
dx = x - prev_x
346348
dy = y - prev_y
347-
is_last = (i == len(abs_points) - 1)
349+
is_last = i == len(abs_points) - 1
348350

349351
if is_last:
350352
p1, p2, p3 = 0, 0, 1 # end-of-sequence

backend/app/ml/llm_engine.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@
3737
# Engine Configuration
3838
# ============================================================
3939

40+
4041
class QuantizationMode(str, Enum):
4142
"""Quantization strategy for model weights."""
43+
4244
NONE = "fp16"
4345
INT8 = "int8"
4446
INT4 = "int4"
@@ -51,6 +53,7 @@ class EngineConfig:
5153
Maps directly to Settings fields — separated so the engine
5254
has no hard dependency on FastAPI/Pydantic.
5355
"""
56+
5457
model_name: str = "inkforge-lstm-mdn-v1"
5558
checkpoint_path: str = "checkpoints/lstm_mdn_v1_best.pt"
5659
device: str = "cpu"
@@ -66,6 +69,7 @@ class EngineConfig:
6669
@dataclass
6770
class EngineStatus:
6871
"""Runtime status snapshot of the engine."""
72+
6973
model_loaded: bool = False
7074
model_name: str = ""
7175
engine_backend: str = "mock"
@@ -87,6 +91,7 @@ class EngineStatus:
8791
# Singleton LLM Engine
8892
# ============================================================
8993

94+
9095
class LLMEngine:
9196
"""
9297
Singleton engine that manages the lifecycle of the LSTM+MDN model.
@@ -179,10 +184,11 @@ async def initialize_model(self, config: EngineConfig) -> None:
179184
if config.device == "cuda":
180185
try:
181186
import torch
187+
182188
if torch.cuda.is_available():
183189
self._gpu_name = torch.cuda.get_device_name(0)
184190
self._vram_total_gb = round(
185-
torch.cuda.get_device_properties(0).total_memory / (1024 ** 3), 1
191+
torch.cuda.get_device_properties(0).total_memory / (1024**3), 1
186192
)
187193
logger.info(f" → CUDA device found: {self._gpu_name}")
188194
logger.info(f" → Total VRAM: {self._vram_total_gb} GB")
@@ -272,6 +278,7 @@ async def shutdown(self) -> None:
272278
# Clear CUDA cache if available
273279
try:
274280
import torch
281+
275282
if torch.cuda.is_available():
276283
torch.cuda.empty_cache()
277284
except ImportError:
@@ -493,8 +500,12 @@ async def _stream_mock(
493500
line_num += 1
494501

495502
# Baseline drift (Exaggerated for visual awareness in mock mode)
496-
global_drift = baseline_drift * 8.0 * math.sin(
497-
2 * math.pi * line_num / max(total_words / 5.0, 3.0) + random.uniform(0, 0.5)
503+
global_drift = (
504+
baseline_drift
505+
* 8.0
506+
* math.sin(
507+
2 * math.pi * line_num / max(total_words / 5.0, 3.0) + random.uniform(0, 0.5)
508+
)
498509
)
499510

500511
char_x = cursor_x
@@ -543,7 +554,7 @@ async def _stream_mock(
543554
char_x += dx
544555

545556
delay = config.stream_chunk_delay_ms / 1000.0
546-
delay *= random.uniform(0.2, 0.8) # faster for mock strokes
557+
delay *= random.uniform(0.2, 0.8) # faster for mock strokes
547558
await asyncio.sleep(delay)
548559

549560
# Pen-up between words
@@ -580,8 +591,7 @@ async def _stream_mock(
580591
}
581592

582593
logger.info(
583-
f"[req-{request_id}] Complete (mock): "
584-
f"{stroke_index} strokes, {line_num + 1} lines"
594+
f"[req-{request_id}] Complete (mock): {stroke_index} strokes, {line_num + 1} lines"
585595
)
586596

587597
# --------------------------------------------------------

backend/app/ml/model.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(
8989
# Total per mixture: 6 parameters
9090
# Plus 3 pen state logits
9191
mdn_output_dim = num_mixtures * 6 # π, μx, μy, σx, σy, ρ
92-
pen_state_dim = 3 # p1, p2, p3
92+
pen_state_dim = 3 # p1, p2, p3
9393

9494
self.mdn_head = nn.Linear(hidden_dim, mdn_output_dim)
9595
self.pen_head = nn.Linear(hidden_dim, pen_state_dim)
@@ -163,18 +163,18 @@ def sample(
163163
# Each component has 6 params, total = M*6
164164
params = mdn_params.view(num_m, 6)
165165

166-
pi_logits = params[:, 0] # Mixture weights (logits)
167-
mu_x = params[:, 1] # Mean x
168-
mu_y = params[:, 2] # Mean y
169-
sigma_x = torch.exp(params[:, 3]) # Std x (exp to ensure positive)
170-
sigma_y = torch.exp(params[:, 4]) # Std y
171-
rho = torch.tanh(params[:, 5]) # Correlation (tanh to bound [-1, 1])
166+
pi_logits = params[:, 0] # Mixture weights (logits)
167+
mu_x = params[:, 1] # Mean x
168+
mu_y = params[:, 2] # Mean y
169+
sigma_x = torch.exp(params[:, 3]) # Std x (exp to ensure positive)
170+
sigma_y = torch.exp(params[:, 4]) # Std y
171+
rho = torch.tanh(params[:, 5]) # Correlation (tanh to bound [-1, 1])
172172

173173
# 2. Apply temperature
174174
# Scale mixture logits by 1/τ, scale sigmas by √τ
175175
pi = torch.softmax(pi_logits / temperature, dim=0)
176-
sigma_x = sigma_x * (temperature ** 0.5)
177-
sigma_y = sigma_y * (temperature ** 0.5)
176+
sigma_x = sigma_x * (temperature**0.5)
177+
sigma_y = sigma_y * (temperature**0.5)
178178

179179
# 3. Sample mixture component from categorical distribution
180180
mixture_idx = torch.multinomial(pi, 1).item()
@@ -193,7 +193,7 @@ def sample(
193193
z2 = torch.randn(1).item()
194194

195195
dx = mu_x_k + sigma_x_k * z1
196-
dy = mu_y_k + sigma_y_k * (rho_k * z1 + (1 - rho_k ** 2) ** 0.5 * z2)
196+
dy = mu_y_k + sigma_y_k * (rho_k * z1 + (1 - rho_k**2) ** 0.5 * z2)
197197

198198
# 5. Sample pen state from Bernoulli
199199
pen_probs = torch.softmax(pen_logits / temperature, dim=0)
@@ -256,19 +256,16 @@ def __init__(self, style_dim: int = STYLE_DIM, input_channels: int = 1) -> None:
256256
nn.BatchNorm2d(32),
257257
nn.ReLU(inplace=True),
258258
nn.MaxPool2d(2, 2),
259-
260259
# Block 2: 32 -> 64 channels
261260
nn.Conv2d(32, 64, kernel_size=3, padding=1),
262261
nn.BatchNorm2d(64),
263262
nn.ReLU(inplace=True),
264263
nn.MaxPool2d(2, 2),
265-
266264
# Block 3: 64 -> 128 channels
267265
nn.Conv2d(64, 128, kernel_size=3, padding=1),
268266
nn.BatchNorm2d(128),
269267
nn.ReLU(inplace=True),
270268
nn.MaxPool2d(2, 2),
271-
272269
# Block 4: 128 -> 256 channels
273270
nn.Conv2d(128, 256, kernel_size=3, padding=1),
274271
nn.BatchNorm2d(256),

backend/app/ml/utils.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,18 @@ def compute_mdn_loss(
3232
Scalar loss tensor.
3333
"""
3434
batch_size, seq_len, _ = mdn_params.shape
35-
num_m = num_mixtures # Renamed num_m to M for consistency with patch
35+
num_m = num_mixtures # Renamed num_m to M for consistency with patch
3636

3737
# Reshape MDN params: [batch, seq, M*6] -> [batch, seq, M, 6]
3838
params = mdn_params.view(batch_size, seq_len, num_m, 6)
3939

4040
# Extract parameters
41-
pi_logits = params[:, :, :, 0] # [batch, seq, M]
42-
mu_x = params[:, :, :, 1] # [batch, seq, M]
43-
mu_y = params[:, :, :, 2] # [batch, seq, M]
44-
sigma_x_raw = params[:, :, :, 3] # [batch, seq, M]
45-
sigma_y_raw = params[:, :, :, 4] # [batch, seq, M]
46-
rho_raw = params[:, :, :, 5] # [batch, seq, M]
41+
pi_logits = params[:, :, :, 0] # [batch, seq, M]
42+
mu_x = params[:, :, :, 1] # [batch, seq, M]
43+
mu_y = params[:, :, :, 2] # [batch, seq, M]
44+
sigma_x_raw = params[:, :, :, 3] # [batch, seq, M]
45+
sigma_y_raw = params[:, :, :, 4] # [batch, seq, M]
46+
rho_raw = params[:, :, :, 5] # [batch, seq, M]
4747

4848
# Apply activations to ensure valid parameter ranges
4949
sigma_x = torch.exp(sigma_x_raw)
@@ -63,18 +63,18 @@ def compute_mdn_loss(
6363

6464
dx = (target_x - mu_x) / sigma_x
6565
dy = (target_y - mu_y) / sigma_y
66-
rho_sq = rho ** 2
66+
rho_sq = rho**2
6767

6868
# Avoid division by zero
6969
one_minus_rho_sq = (1 - rho_sq).clamp(min=1e-6)
7070

71-
z_val = dx ** 2 + dy ** 2 - 2 * rho * dx * dy # Corrected Z variable name from patch
71+
z_val = dx**2 + dy**2 - 2 * rho * dx * dy # Corrected Z variable name from patch
7272
log_gaussian = (
7373
-math.log(2 * math.pi)
74-
- torch.log(sigma_x) # Changed from log_sigma_x
75-
- torch.log(sigma_y) # Changed from log_sigma_y
74+
- torch.log(sigma_x) # Changed from log_sigma_x
75+
- torch.log(sigma_y) # Changed from log_sigma_y
7676
- 0.5 * torch.log(one_minus_rho_sq)
77-
- z_val / (2 * one_minus_rho_sq) # Changed from Z to z_val
77+
- z_val / (2 * one_minus_rho_sq) # Changed from Z to z_val
7878
) # [batch, seq, M]
7979

8080
# Weighted sum using log-sum-exp: log(Σ π_k * N_k) = logsumexp(log π_k + log N_k)
@@ -86,7 +86,9 @@ def compute_mdn_loss(
8686
# Pen state loss: cross-entropy (target must be long/int64)
8787
pen_logits_flat = pen_logits.view(-1, 3) # [batch*seq, 3]
8888
target_pen_flat = target_pen.long().view(-1) # [batch*seq]
89-
pen_loss = functional.cross_entropy(pen_logits_flat, target_pen_flat) # Corrected syntax from patch
89+
pen_loss = functional.cross_entropy(
90+
pen_logits_flat, target_pen_flat
91+
) # Corrected syntax from patch
9092

9193
# Total loss
9294
total_loss = stroke_loss + pen_loss
@@ -176,7 +178,9 @@ def strokes_to_absolute(strokes: list[tuple[float, ...]]) -> list[tuple[float, f
176178
return absolute
177179

178180

179-
def absolute_to_strokes(points: list[tuple[float, float, int]]) -> list[tuple[float, float, int, int, int]]:
181+
def absolute_to_strokes(
182+
points: list[tuple[float, float, int]],
183+
) -> list[tuple[float, float, int, int, int]]:
180184
"""
181185
Convert absolute coordinates to relative stroke deltas.
182186

0 commit comments

Comments (0)