-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
138 lines (84 loc) · 5.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import time
from matplotlib import pyplot as plt
import numpy as np
from sklearn.linear_model import LassoCV, Ridge, RidgeCV, ridge_regression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV, RepeatedKFold
from code.em_algorithm import EMAlgorithm
from code.lasso import ElasticNet, lambda_search
from code.utils import cv_boxplot, execution_time, generate_data, load_data, plot_E_beta, plot_E_beta_2, plot_E_beta_2_diag, plot_E_beta_difference, split_train_test, cross_validate
from code.variational_inference import VariationalInference
def em_algorithm(y, Z, X, omega_init, sigma_b2_init, sigma_e2_init, max_iter=200, tol=1e-6, cv_folds=5):
    """Fit the EM model on the full data, report diagnostics, and cross-validate it.

    Parameters
    ----------
    y, Z, X
        Data as returned by ``load_data``; forwarded unchanged to
        ``EMAlgorithm.fit`` / ``EMAlgorithm.predict`` and ``cross_validate``.
    omega_init, sigma_b2_init, sigma_e2_init
        Initial values passed to ``EMAlgorithm`` as ``omega``, ``sigma_b2``
        and ``sigma_e2``.
    max_iter, tol
        Iteration cap and convergence tolerance passed to ``EMAlgorithm``.
    cv_folds
        Number of folds (``n_splits``) used by ``cross_validate``.

    Returns
    -------
    The per-fold CV MSEs returned by ``cross_validate`` (also written to
    ``outputs/em_cv_mses.csv``).

    Side effects: prints the Theta estimate, the train MSE and the CV MSEs;
    writes ``outputs/log_marginal_likelihoods.png`` and
    ``outputs/em_cv_mses.csv``.
    """
    # ---------------- EM algorithm ----------------
    # Create the output directory up front: otherwise the first write below
    # only fails after the (potentially slow) fit has already finished.
    os.makedirs('outputs', exist_ok=True)

    em_params = {'max_iter': max_iter, 'tol': tol, 'omega': omega_init, 'sigma_b2': sigma_b2_init, 'sigma_e2': sigma_e2_init}
    em = EMAlgorithm(**em_params)
    # execution_time is used as a wrapper around em.fit(y, Z, X) whose result is
    # unpacked into three values (presumably it also reports the run time --
    # confirm in code.utils). Only Theta is used here.
    _log_marginal_likelihoods, Theta, _E_beta = execution_time(em.fit, y, Z, X)
    print("Estimate of Theta:", Theta)

    y_train_pred = em.predict(Z, X)
    print('EM Train MSE:', mean_squared_error(y, y_train_pred))

    # Plot marginal likelihood
    em.plot_marginal_likelihood('outputs/log_marginal_likelihoods.png')

    # Cross validation: a fresh EMAlgorithm(**em_params) per fold (see cross_validate).
    cv_mses = cross_validate(y, Z, X, EMAlgorithm, em_params, n_splits=cv_folds)
    print('CV Mean MSE:', np.mean(cv_mses))
    print('CV MSEs:', cv_mses)
    np.savetxt('outputs/em_cv_mses.csv', cv_mses, delimiter=',')
    return cv_mses
def variational_inference(y, Z, X, omega_init, sigma_b2_init, sigma_e2_init, max_iter=200, tol=1e-6, cv_folds=5):
    """Fit the variational-inference model on the full data and report diagnostics.

    Parameters
    ----------
    y, Z, X
        Data as returned by ``load_data``; forwarded unchanged to
        ``VariationalInference.fit`` / ``VariationalInference.predict``.
    omega_init, sigma_b2_init, sigma_e2_init
        Initial values passed to ``VariationalInference`` as ``omega``,
        ``sigma_b2`` and ``sigma_e2``.
    max_iter, tol
        Iteration cap and convergence tolerance passed to the model.
    cv_folds
        Currently unused: the cross-validation step at the end of this
        function is commented out. Kept so the signature mirrors
        ``em_algorithm``.

    Side effects: prints the log marginal likelihoods, the Theta estimate and
    the train MSE; writes ``outputs/log_marginal_likelihoods_vi.png`` and
    ``outputs/gap.png``.
    """
    # ---------------- Variational inference ----------------
    # Create the output directory before fitting so the plot writes below
    # cannot fail after all the work is done.
    os.makedirs('outputs', exist_ok=True)

    vi_params = {'max_iter': max_iter, 'tol': tol, 'omega': omega_init, 'sigma_b2': sigma_b2_init, 'sigma_e2': sigma_e2_init}
    vi = VariationalInference(**vi_params)
    # execution_time is used as a wrapper around vi.fit(y, Z, X) whose result
    # unpacks into two values (presumably it also reports timing -- confirm).
    log_marginal_likelihoods, Theta = execution_time(vi.fit, y, Z, X)
    print("Log marginal likelihoods:", log_marginal_likelihoods)
    print("Estimate of Theta:", Theta)

    y_train_pred = vi.predict(Z, X)
    # This is the VI model's training error, so label it as such (it was
    # previously mislabelled as EM, copied from em_algorithm).
    print('VI Train MSE:', mean_squared_error(y, y_train_pred))

    # Plot marginal likelihood / ELBO and the gap between them
    vi.plot_marginal_likelihood_and_elbo('outputs/log_marginal_likelihoods_vi.png')
    vi.plot_gap('outputs/gap.png')

    # Cross validation (disabled; uncomment to run it with `cv_folds` folds)
    # cv_mses = cross_validate(y, Z, X, VariationalInference, vi_params, n_splits=cv_folds)
    # print('CV Mean MSE:', np.mean(cv_mses))
    # print('CV MSEs:', cv_mses)
    # np.savetxt('outputs/vi_cv_mses.csv', cv_mses, delimiter=',')
def lasso(y, Z, X, cv_folds=10, alpha=0.01227734):
    """Fit a pure-L1 ``ElasticNet`` (LASSO) and run a CV search over the penalty.

    Parameters
    ----------
    y, Z, X
        Data forwarded to ``ElasticNet.fit`` / ``ElasticNet.predict`` and to
        ``lambda_search``.
    cv_folds
        Number of folds used by ``lambda_search``.
    alpha
        Penalty strength for the single full-data fit. The default is the
        value this script previously hard-coded (presumably an earlier
        search optimum -- TODO confirm).

    Returns
    -------
    ``(min_mse, min_average_row)`` as returned by ``lambda_search``.

    Side effects: prints train MSE, coefficient estimates and search results;
    writes ``outputs/beta_lasso.npy``, ``outputs/lasso_cv_mses_all.csv`` (via
    ``lambda_search``'s ``save_path``) and ``outputs/lasso_cv_mses.csv``.
    """
    # ---------------- LASSO -----------------
    os.makedirs('outputs', exist_ok=True)

    lasso_params = {'alpha': alpha, 'fit_intercept': False, 'l1_ratio': 1.0, 'max_iter': 1000}
    # Named `model` (not `lasso`) so the local does not shadow this function.
    model = ElasticNet(**lasso_params)
    execution_time(model.fit, y, Z, X)
    y_pred = model.predict(Z, X)
    print('LASSO Train MSE:', mean_squared_error(y, y_pred))  # Calculate MSE

    # coef_ is split at Z.shape[1]: the first Z.shape[1] entries are taken as
    # omega, the remaining ones as beta.
    omega_beta = model.coef_
    n_omega = Z.shape[1]
    omega = omega_beta[:n_omega]
    beta = omega_beta[n_omega:]
    np.save('outputs/beta_lasso.npy', beta)
    print("Estimated omega:", omega)
    print("Estimated beta:", beta)

    # Search for the best lambda on a log-spaced grid exp(-8), exp(-7.9), ... < exp(0)
    lambdas = np.exp(np.arange(-8, 0, 0.1))
    min_mse, min_average_row = execution_time(
        lambda_search, y, Z, X, ElasticNet, lasso_params,
        lambda_list=lambdas, save_path='outputs/lasso_cv_mses_all.csv', cv_folds=cv_folds,
    )
    print("Minimum MSE:", min_mse)
    print("Minimum average row:", min_average_row)
    # Drop the first entry of the row (presumably the lambda value itself --
    # confirm against lambda_search) and keep the per-fold MSEs.
    np.savetxt('outputs/lasso_cv_mses.csv', min_average_row[1:], delimiter=',')
    return min_mse, min_average_row
if __name__ == '__main__':
    # Experiment driver. Individual steps are switched on/off by (un)commenting
    # them. The live plotting steps read .npy files under outputs/ --
    # presumably produced by earlier runs of the commented-out fitting steps.
    y, Z, X = load_data(path='data/lmm_y_z_x.txt')

    # Posterior-mean files shared by the per-method plots and the comparison.
    em_beta_path = 'outputs/E_beta_em.npy'
    vi_beta_path = 'outputs/E_beta_vi.npy'

    # ---------------- EM -----------------
    # em_algorithm(y, Z, X, omega_init=np.zeros(10), sigma_b2_init=1, sigma_e2_init=1, max_iter=1000, tol=1e-6, cv_folds=10)
    # plot_E_beta_2(data_path='outputs/E_beta_2_em.npy', save_path='outputs/E_beta_2_em.png')
    # plot_E_beta_2_diag(data_path='outputs/E_beta_2_em.npy', save_path='outputs/E_beta_2_diag_em.png')
    plot_E_beta(data_path=em_beta_path, save_path='outputs/E_beta_em.png')

    # ---------------- Lasso -----------------
    lasso(y, Z, X, cv_folds=10)
    # cv_boxplot(cv_mses_dirs=['outputs/em_cv_mses.csv', 'outputs/lasso_cv_mses.csv'], save_path='outputs/cv_boxplot.png', labels=['EM', 'LASSO'])

    # ---------------- VI -----------------
    # variational_inference(y, Z, X, omega_init=np.zeros(10), sigma_b2_init=0.8, sigma_e2_init=0.2, max_iter=100, tol=1e-6, cv_folds=10)
    # plot_E_beta_2(data_path='outputs/E_beta_2_vi.npy', save_path='outputs/E_beta_2_vi.png')
    # plot_E_beta_2_diag(data_path='outputs/E_beta_2_diag_vi.npy', save_path='outputs/E_beta_2_diag_vi.png')
    plot_E_beta(data_path=vi_beta_path, save_path='outputs/E_beta_vi.png')
    # cv_boxplot(cv_mses_dirs=['outputs/em_cv_mses.csv', 'outputs/lasso_cv_mses.csv', 'outputs/vi_cv_mses.csv'], save_path='outputs/cv_boxplot_2.png', labels=['EM', 'LASSO', 'VI'])

    # ---------------- EM vs VI -----------------
    plot_E_beta_difference(data_path_em=em_beta_path, data_path_vi=vi_beta_path, save_path='outputs/E_beta_difference.png')