-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
27 lines (22 loc) · 710 Bytes
/
train.py
File metadata and controls
27 lines (22 loc) · 710 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import deeptrack as dt
import numpy as np
import random
import os
# Fine-tune a previously trained LodeSTAR model on the images in
# training_data/big_particles, then write the updated weights back to the
# same location they were loaded from.

# Build paths with os.path.join so the script also runs on POSIX systems; on
# Windows these yield the same strings as the original backslash literals.
data_folder = os.path.join("training_data", "big_particles")
weights_path = os.path.join("models", "big_particles", "weights")

samples = []
# Sort for a reproducible load order, and skip sub-directories and hidden
# files (e.g. ".DS_Store") that os.listdir also returns.
for sample in sorted(os.listdir(data_folder)):
    sample_path = os.path.join(data_folder, sample)
    if sample.startswith(".") or not os.path.isfile(sample_path):
        continue
    image = dt.LoadImage(sample_path)()
    # Resolving the feature yields a wrapper exposing the array as `_value`
    # (as the original code relied on); some DeepTrack versions may return a
    # plain ndarray instead, so fall back to the object itself.
    image = np.asarray(getattr(image, "_value", image))
    if image.ndim != 3 or image.shape[-1] < 3:
        raise ValueError(
            f"Expected an (H, W, C>=3) image, got shape {image.shape} "
            f"from {sample_path!r}"
        )
    # Keep the first three channels (drops a 4th, e.g. alpha, if present) to
    # match input_shape=(None, None, 3) below.
    # NOTE(review): 8-bit data peaks at 255, so /256 maps to [0, ~0.996].
    # Kept unchanged so inputs match the preprocessing the saved weights
    # were trained with -- confirm against inference code before using /255.
    samples.append(image[:, :, :3] / 256)

if not samples:
    # Fail fast: otherwise random.choice([]) only raises an IndexError once
    # model.fit starts drawing training images.
    raise FileNotFoundError(f"No training images found in {data_folder!r}")

model = dt.models.LodeSTAR(input_shape=(None, None, 3))
# Resume from the previously saved weights (this script fine-tunes; it
# requires the weights to exist).
model.load_weights(weights_path)

# Augmentation pipeline. Each stage is given a callable; presumably DeepTrack
# re-evaluates these on every update so each training image gets fresh
# random values -- confirm against DeepTrack docs.
train_set = (
    # Pick one of the loaded images at random.
    dt.Value(lambda: random.choice(samples))
    # Offset drawn from a normal distribution with std 0.1.
    >> dt.Add(lambda: np.random.randn() * 0.1)
    # Gaussian noise with sigma drawn uniformly from [0, 0.2).
    >> dt.Gaussian(sigma=lambda: np.random.uniform(0, 0.2))
    # Scale factor drawn uniformly from [0.6, 1.2).
    >> dt.Multiply(lambda: np.random.uniform(0.6, 1.2))
)

model.fit(
    train_set,
    epochs=40,
    batch_size=8,
)

# Overwrite the weights that were loaded above with the fine-tuned ones.
model.save_weights(weights_path)