@@ -32,18 +32,18 @@ def compute_mdn_loss(
3232 Scalar loss tensor.
3333 """
3434 batch_size , seq_len , _ = mdn_params .shape
35- num_m = num_mixtures # Renamed num_m to M for consistency with patch
35+ num_m = num_mixtures # Renamed num_m to M for consistency with patch
3636
3737 # Reshape MDN params: [batch, seq, M*6] -> [batch, seq, M, 6]
3838 params = mdn_params .view (batch_size , seq_len , num_m , 6 )
3939
4040 # Extract parameters
41- pi_logits = params [:, :, :, 0 ] # [batch, seq, M]
42- mu_x = params [:, :, :, 1 ] # [batch, seq, M]
43- mu_y = params [:, :, :, 2 ] # [batch, seq, M]
44- sigma_x_raw = params [:, :, :, 3 ] # [batch, seq, M]
45- sigma_y_raw = params [:, :, :, 4 ] # [batch, seq, M]
46- rho_raw = params [:, :, :, 5 ] # [batch, seq, M]
41+ pi_logits = params [:, :, :, 0 ] # [batch, seq, M]
42+ mu_x = params [:, :, :, 1 ] # [batch, seq, M]
43+ mu_y = params [:, :, :, 2 ] # [batch, seq, M]
44+ sigma_x_raw = params [:, :, :, 3 ] # [batch, seq, M]
45+ sigma_y_raw = params [:, :, :, 4 ] # [batch, seq, M]
46+ rho_raw = params [:, :, :, 5 ] # [batch, seq, M]
4747
4848 # Apply activations to ensure valid parameter ranges
4949 sigma_x = torch .exp (sigma_x_raw )
@@ -63,18 +63,18 @@ def compute_mdn_loss(
6363
6464 dx = (target_x - mu_x ) / sigma_x
6565 dy = (target_y - mu_y ) / sigma_y
66- rho_sq = rho ** 2
66+ rho_sq = rho ** 2
6767
6868 # Avoid division by zero
6969 one_minus_rho_sq = (1 - rho_sq ).clamp (min = 1e-6 )
7070
71- z_val = dx ** 2 + dy ** 2 - 2 * rho * dx * dy # Corrected Z variable name from patch
71+ z_val = dx ** 2 + dy ** 2 - 2 * rho * dx * dy # Corrected Z variable name from patch
7272 log_gaussian = (
7373 - math .log (2 * math .pi )
74- - torch .log (sigma_x ) # Changed from log_sigma_x
75- - torch .log (sigma_y ) # Changed from log_sigma_y
74+ - torch .log (sigma_x ) # Changed from log_sigma_x
75+ - torch .log (sigma_y ) # Changed from log_sigma_y
7676 - 0.5 * torch .log (one_minus_rho_sq )
77- - z_val / (2 * one_minus_rho_sq ) # Changed from Z to z_val
77+ - z_val / (2 * one_minus_rho_sq ) # Changed from Z to z_val
7878 ) # [batch, seq, M]
7979
8080 # Weighted sum using log-sum-exp: log(Σ π_k * N_k) = logsumexp(log π_k + log N_k)
@@ -86,7 +86,9 @@ def compute_mdn_loss(
8686 # Pen state loss: cross-entropy (target must be long/int64)
8787 pen_logits_flat = pen_logits .view (- 1 , 3 ) # [batch*seq, 3]
8888 target_pen_flat = target_pen .long ().view (- 1 ) # [batch*seq]
89- pen_loss = functional .cross_entropy (pen_logits_flat , target_pen_flat ) # Corrected syntax from patch
89+ pen_loss = functional .cross_entropy (
90+ pen_logits_flat , target_pen_flat
91+ ) # Corrected syntax from patch
9092
9193 # Total loss
9294 total_loss = stroke_loss + pen_loss
@@ -176,7 +178,9 @@ def strokes_to_absolute(strokes: list[tuple[float, ...]]) -> list[tuple[float, f
176178 return absolute
177179
178180
179- def absolute_to_strokes (points : list [tuple [float , float , int ]]) -> list [tuple [float , float , int , int , int ]]:
181+ def absolute_to_strokes (
182+ points : list [tuple [float , float , int ]],
183+ ) -> list [tuple [float , float , int , int , int ]]:
180184 """
181185 Convert absolute coordinates to relative stroke deltas.
182186
0 commit comments