# linefit_gradient.py
#
# Fit a line a*x + b*y + c = 0 to 2-D points by gradient descent on
# sum((a*x + b*y + c)**2 / (a**2 + b**2)), plotting the fit with matplotlib.
import numpy as np
import matplotlib.pyplot as plt
def _draw_fit_frame(x, y, a, b, c):
    """Draw one animation frame: the data points and the line a*x + b*y + c = 0."""
    plt.scatter(x, y, color='blue', label='Data points')
    # Solve the implicit equation for y on an even 100-point grid spanning
    # the data's x range.
    x_vals = np.linspace(x.min(), x.max(), 100)
    y_vals = -(a * x_vals + c) / b
    plt.plot(x_vals, y_vals, color='red', label='Fitted line')
    # NOTE: the x-axis window is fixed at [0, 6]; data outside that range is
    # drawn but not visible.
    plt.xlim(0, 6)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.legend()
    plt.title('Line Fitting using Gradient Descent')
    # Show the frame briefly, then clear the figure for the next frame.
    plt.pause(0.1)
    plt.clf()


def gradient_descent(x, y, learning_rate=0.01, num_iterations=100, *,
                     plot=True, verbose=True):
    """Fit the implicit line ``a*x + b*y + c = 0`` to points by gradient descent.

    The objective is ``sum((a*x + b*y + c)**2 / (a**2 + b**2))``. The search
    starts from ``a = b = c = 1.0`` and takes ``num_iterations`` fixed-size
    steps against the analytic gradient.

    Args:
        x: Array-like of x coordinates (converted to a float ndarray).
        y: Array-like of y coordinates (converted to a float ndarray).
        learning_rate: Step size applied to each gradient component.
        num_iterations: Number of update steps; ``0`` returns the initial
            parameters unchanged.
        plot: If true (the default), draw the data and the current line with
            matplotlib after every step. Pass ``False`` to run without
            touching matplotlib at all.
        verbose: If true (the default), print the loss after every step.

    Returns:
        The tuple ``(a, b, c)`` of fitted line coefficients.
    """
    # Accept lists/tuples as well as ndarrays; float * list would otherwise
    # raise a TypeError.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Initial parameters a, b, c
    a = 1.0
    b = 1.0
    c = 1.0

    for step in range(num_iterations):
        # Common denominator of the objective and the shared residual term.
        d = (a**2 + b**2)
        r = a * x + b * y + c

        # Quotient-rule gradients of sum(r**2 / d) with respect to a, b, c.
        da = 2 * np.sum(r * (x * d - a * r) / d**2)
        db = 2 * np.sum(r * (y * d - b * r) / d**2)
        dc = 2 * np.sum(r / d)

        # Update parameters
        a -= learning_rate * da
        b -= learning_rate * db
        c -= learning_rate * dc

        if verbose:
            # Loss is evaluated with the freshly updated parameters.
            loss = np.sum((a*x + b*y + c)**2/(a**2 + b**2))
            print(f"Iter={step}: loss = {loss}")

        if plot:
            _draw_fit_frame(x, y, a, b, c)

    return a, b, c
# Example usage: fit a line to five sample points and show the result.
x = np.array([1, 2, 3, 4, 5])
y = np.array([2, 2, 3, 5, 4])

# Optimise the line coefficients using the function's default settings.
a, b, c = gradient_descent(x, y)
print(f"Fitted line parameters: a = {a}, b = {b}, c = {c}")

# Final figure: solve a*x + b*y + c = 0 for y on a 100-point grid that spans
# the data's x range, then draw the samples and the fitted line together.
x_vals = np.linspace(min(x), max(x), 100)
y_vals = -(a * x_vals + c) / b
plt.scatter(x, y, color='blue', label='Data points')
plt.plot(x_vals, y_vals, color='red', label='Fitted line')
plt.title('Line Fitting using Gradient Descent')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()