-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathline_fitting_sklearn.py
More file actions
30 lines (25 loc) · 952 Bytes
/
line_fitting_sklearn.py
File metadata and controls
30 lines (25 loc) · 952 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model
def true_line(x):
    """Return the ground-truth line y = -2/3*x + 14/3 (i.e. 2x + 3y = 14) at x."""
    return -2/3*x + 14/3


data_range = np.array([-4, 12])  # x interval [min, max] used for sampling and plotting
data_num = 100                   # number of samples to generate
noise_std = 0.5                  # std. dev. of the Gaussian noise added to x and y
random_seed = None               # set to an int for reproducible runs

if random_seed is not None:
    np.random.seed(random_seed)

# Generate the true data
x = np.random.uniform(data_range[0], data_range[1], size=data_num)
y = true_line(x)  # y = -2/3*x + 14/3

# Add Gaussian noise to both coordinates.
# NOTE: ordinary least squares only models noise in y; noise in x
# (errors-in-variables) biases the fitted slope slightly toward zero.
xn = x + np.random.normal(scale=noise_std, size=x.shape)
yn = y + np.random.normal(scale=noise_std, size=y.shape)

# Train a model; scikit-learn expects a 2-D (n_samples, n_features) input.
X = xn.reshape(-1, 1)
model = linear_model.LinearRegression()
model.fit(X, yn)
score = model.score(X, yn)  # coefficient of determination R^2 on the training data

# Plot the data and result
plt.title(f'Line: y={model.coef_[0]:.3f}*x + {model.intercept_:.3f} (score={score:.3f})')
plt.plot(data_range, true_line(data_range), 'r-', label='The true line')
plt.plot(xn, yn, 'b.', label='Noisy data')
plt.plot(data_range, model.coef_[0]*data_range + model.intercept_, 'g-', label='Estimate')
plt.xlim(data_range)
plt.legend()
plt.show()