Skip to content

Commit

Permalink
fixed discrepancy plotting and now weights are init at every update
Browse files Browse the repository at this point in the history
  • Loading branch information
daneschi committed Feb 1, 2024
1 parent 3d33685 commit 893e986
Show file tree
Hide file tree
Showing 26 changed files with 373 additions and 2,954 deletions.
3 changes: 2 additions & 1 deletion linfa/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from linfa.tests.test_linfa import linfa_test_suite
from linfa.tests.test_linfa import linfa_test_suite
from linfa.tests.test_linfa_discrepancy import linfa_test_suite_discrepancy
40 changes: 23 additions & 17 deletions linfa/discrepancy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -69,21 +70,21 @@ def __init__(self, model_name,

if(var_grid_in is None):

self.device = None
self.input_size = None
self.output_size = None
self.dnn_arch = None
self.device = None
self.input_size = None
self.output_size = None
self.dnn_arch = None
self.dnn_activation = None
self.dnn_dropout = None
self.is_trained = None
self.lf_model = None
self.var_grid_in = None
self.var_grid_out = None
self.var_in_avg = None
self.var_in_std = None
self.var_out_avg = None
self.var_out_std = None
self.surrogate = None
self.dnn_dropout = None
self.is_trained = None
self.lf_model = None
self.var_grid_in = None
self.var_grid_out = None
self.var_in_avg = None
self.var_in_std = None
self.var_out_avg = None
self.var_out_std = None
self.surrogate = None

else:

Expand Down Expand Up @@ -190,9 +191,14 @@ def update(self, batch_x, max_iters=10000, lr=0.01, lr_exp=0.999, record_interva
optimizer = torch.optim.RMSprop(self.surrogate.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, lr_exp)

# Set surrogate in training mode
self.surrogate.train()

# Init weights
self.surrogate.init_weight()

# Loop on iterations - epochs with fill batch
for i in range(max_iters):
# Set surrogate in training mode
self.surrogate.train()

# Surrogate returns a table with rows as batches and columns as variables considered
disc = self.surrogate(var_grid)
Expand All @@ -201,7 +207,7 @@ def update(self, batch_x, max_iters=10000, lr=0.01, lr_exp=0.999, record_interva
# Mean over the columns (batches)
# Mean over the rows (variables)
# Also we need to account for the number of repeated observations
loss = torch.tensor(0.0)
loss = torch.tensor(0.0)
# Loop over the number of observations
for loopA in range(var_out.size(1)):
loss += torch.sum((disc.flatten() - var_out[:,loopA]) ** 2)
Expand Down
7 changes: 7 additions & 0 deletions linfa/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import torch.nn as nn
import torch.nn.functional as F

def init_weights(m):
    """Initialize a single module in place.

    Linear layers receive Xavier-uniform weights and zero biases; every
    other module type is left untouched. Intended to be passed to
    ``nn.Module.apply`` to (re-)initialize a whole network.

    Args:
        m: any ``nn.Module`` submodule handed in by ``apply``.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # Guard against Linear(..., bias=False): there m.bias is None and
        # the previous unconditional fill_ raised an AttributeError.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

class FNN(nn.Module):
"""Fully Connected Neural Network"""

Expand Down Expand Up @@ -38,6 +43,8 @@ def __init__(self, input_size, output_size, arch=None, activation='relu', device
self.dropout_visible = nn.Dropout(p=self.dropout[0])
self.dropout_hidden = nn.Dropout(p=self.dropout[1])

def init_weight(self):
self.apply(init_weights)

def forward(self, x):
"""
Expand Down
62 changes: 32 additions & 30 deletions linfa/plot_disc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@ def plot_disr_histograms(lf_file, lf_dicr_file, lf_discr_noise_file, data_file,
# Check for repeated observations
## shape : no. var inputs pairs x no. batches
num_dim = len(np.shape(lf_model))

if num_dim == 1:

if(num_dim == 1):

# DES: NEED TO TEST FOR TP1 I THINK !!!

# Plot histograms
plt.figure(figsize = (6,4))
ax = plt.gca()
plt.hist(lf_model, label = r'$\eta \vert \mathbf{\theta}$', alpha = 0.5, density = True, hatch = '/')
plt.hist(lf_model_plus_disc, label = r'$\zeta \vert \mathbf{\theta}, \delta$', alpha = 0.5, density = True)
plt.hist(lf_model_plus_disc_plus_noise, label = r'$y \vert \mathbf{\theta}, \delta, \epsilon$', alpha = 0.5, density = True, hatch = '.')

for loopA in range(len(data[2:])):
if loopA == 0:
# add label to legend
Expand Down Expand Up @@ -84,10 +87,7 @@ def plot_disr_histograms(lf_file, lf_dicr_file, lf_discr_noise_file, data_file,
sample[loopA, loopB] = lf_model_plus_disc[loopA, loopB, random_array]

# Plot function & save line properties for legend
plt.plot(np.tile(pressures, (sample_size, 1)).transpose(),
sample[loopA],
linewidth=0.01,
color=clrs[loopA])
plt.plot(np.tile(pressures, (sample_size, 1)).transpose(), sample[loopA], linewidth=0.1, color=clrs[loopA])

for loopC in range(no_reps):
line = plt.plot(np.tile(pressures, (sample_size, 1)).transpose(),
Expand Down Expand Up @@ -197,6 +197,11 @@ def plot_discr_surface_2d(file_path, lf_file, data_file, num_1d_grid_points, dat
else:
# Define test grid
test_grid = prep_test_grid(dicr.var_grid_in, data_limit_factor, num_1d_grid_points)

## Variable input 1
min_dim_1 = torch.min(dicr.var_grid_in[:,0])
max_dim_1 = torch.max(dicr.var_grid_in[:,0])
min_dim_1, max_dim_1 = scale_limits(min_dim_1, max_dim_1, data_limit_factor)

## Variable input 2
min_dim_2 = torch.min(dicr.var_grid_in[:,1])
Expand All @@ -211,31 +216,26 @@ def plot_discr_surface_2d(file_path, lf_file, data_file, num_1d_grid_points, dat

# Evaluate discrepancy over test grid
res = dicr.forward(test_grid)

# Assign observations from data
observations = exp_data.transpose()[num_var_ins:]
observations = exp_data[:,num_var_ins:]

# Assign variable inputs
var_train = [exp_data[:, 0], exp_data[:, 1]]

## Assign discrepancy
# Check for repeated observations
if np.shape(observations)[0] > 1:
print('code needs to be debugged here') # TODO
disc = np.average(observations) - lf_train
exit()
else:
disc = []
for loopA in range(len(observations.flatten())):
disc.append(observations[0][loopA] - lf_train[loopA])
# Assign discrepancy target, i.e., obs - lf model predictions for each batch
disc = observations - lf_train

# Compute average discrepancy across batches used for training
train_disc = np.average(disc, axis = 1)
train_disc = disc.mean(axis=1)

# Compute error bars for averaged discrepancy
discBnds = [np.percentile(disc, nom_coverage/ 2, axis = 1), # Lower bound
np.percentile(disc, 100 - nom_coverage / 2, axis = 1)] # Upper bound
discBnds = [np.percentile(disc, 100 - nom_coverage, axis = 1), # 5 percentile
np.percentile(disc, nom_coverage, axis = 1)] # 95 percentile

errBnds = [train_disc - np.percentile(disc, 100 - nom_coverage, axis = 1), # 5 percentile
np.percentile(disc, nom_coverage, axis = 1) - train_disc] # 95 percentile

# For debugging
if True:
print_disc_stats(train_disc, discBnds)
Expand All @@ -250,14 +250,16 @@ def plot_discr_surface_2d(file_path, lf_file, data_file, num_1d_grid_points, dat
ax = plt.figure(figsize = (4,4)).add_subplot(projection='3d')
ax.plot_trisurf(x, y, z, cmap = plt.cm.Spectral, linewidth = 0.2, antialiased = True)
ax.scatter(var_train[0], var_train[1], train_disc, color = 'k', s = 8)
ax.errorbar(var_train[0], var_train[1], train_disc, zerr = discBnds, fmt = 'o', color = 'k', ecolor = 'k', capsize = 3)
ax.errorbar(var_train[0], var_train[1], train_disc, zerr = errBnds, fmt = 'o', color = 'k', ecolor = 'k', capsize = 3)
ax.set_xlabel('Temperature [K]', fontsize = 16, fontweight = 'bold', labelpad = 15)
ax.set_ylabel('Pressure [Pa]', fontsize = 16, fontweight = 'bold', labelpad = 15)
ax.set_zlabel('Discrepancy [ ]', fontsize = 16, fontweight = 'bold', labelpad = 15)
ax.tick_params(axis = 'both', which = 'both', direction = 'in', top = True, right = True, labelsize = 15)
ax.yaxis.set_major_formatter(FormatStrFormatter("%.1f"))
plt.tight_layout()
print('Generating plot...: ',out_dir+'disc_surf_'+ str(step_num) +'.%s' % img_format)
plt.savefig(out_dir+'disc_surf_'+ str(step_num) +'.%s' % img_format, format = img_format, bbox_inches = 'tight', dpi = 300)
# plt.show()

# def eval_discrepancy_custom_grid(file_path, train_grid_in, train_grid_out, test_grid):

Expand Down Expand Up @@ -492,18 +494,18 @@ def plot_marginal_posterior(params_file, step_num, out_dir, img_format = 'png'):

# Set file name/path for lf and discr results
out_dir = args.folder_name + args.exp_name + '/'
lf_file = out_dir + args.exp_name + '_outputs_lf_' + str(args.step_num)
lf_dicr_file = out_dir + args.exp_name + '_outputs_lf+discr_' + str(args.step_num)
lf_file = out_dir + args.exp_name + '_outputs_lf_' + str(args.step_num)
lf_dicr_file = out_dir + args.exp_name + '_outputs_lf+discr_' + str(args.step_num)
lf_discr_noise_file = out_dir + args.exp_name + '_outputs_lf+discr+noise_' + str(args.step_num)
discr_sur_file = out_dir + args.exp_name
data_file = out_dir + args.exp_name + '_data'
marg_stats_file = out_dir + args.exp_name + '_marginal_stats_'
params_file = out_dir + args.exp_name + '_params_' + str(args.step_num)
discr_sur_file = out_dir + args.exp_name
data_file = out_dir + args.exp_name + '_data'
marg_stats_file = out_dir + args.exp_name + '_marginal_stats_'
params_file = out_dir + args.exp_name + '_params_' + str(args.step_num)

if(args.result_mode == 'histograms'):
plot_disr_histograms(lf_file, lf_dicr_file, lf_discr_noise_file, data_file, args.step_num, out_dir, args.img_format)
elif(args.result_mode == 'discr_surface'):
plot_discr_surface_2d(discr_sur_file, lf_dicr_file, data_file, args.num_1d_grid_points, args.data_limit_factor, args.step_num, out_dir, args.img_format)
plot_discr_surface_2d(discr_sur_file, lf_file, data_file, args.num_1d_grid_points, args.data_limit_factor, args.step_num, out_dir, args.img_format)
elif(args.result_mode == 'marginal_stats'):
plot_marginal_stats(marg_stats_file, args.step_num, args.saveinterval, args.img_format, out_dir)
elif(args.result_mode == 'marginal_posterior'):
Expand Down
108 changes: 92 additions & 16 deletions linfa/plot_res.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,34 @@
plt.rc('ytick', labelsize='x-small')
plt.rc('text', usetex = False)

def plot_output(output_file,obs_file,out_dir,out_info):
    """Plot all pairwise 2D scatter plots of posterior predictive samples
    against the observations.

    For every pair (i, j) of output columns in *output_file*, delegates to
    ``plot_outputs`` to draw samples vs. observations. If either file is
    missing, prints a diagnostic message instead of plotting.

    Args:
        output_file: text file of posterior predictive samples
            (rows = samples, columns = outputs).
        obs_file: text file with the observations.
        out_dir: output directory for the generated figures.
        out_info: string embedded in the figure file names.

    Note:
        Reads ``args.img_format`` and ``args.use_dark_mode`` from the
        module-level CLI namespace, so it is only callable from this script.
    """
    if(os.path.isfile(output_file) and os.path.isfile(obs_file)):
        # ndmin=2 keeps a single-output file from collapsing to a 1-D
        # array, where .shape[1] would raise an IndexError.
        tot_outputs = np.loadtxt(output_file, ndmin=2).shape[1]
        print('Plotting posterior predictive samples...')
        for loopA in range(tot_outputs):
            for loopB in range(loopA+1, tot_outputs):
                plot_outputs(output_file,obs_file,loopA,loopB,out_dir,out_info,fig_format=args.img_format,use_dark_mode=args.use_dark_mode)
    else:
        if not(os.path.isfile(output_file)):
            print('File with posterior predictive samples not found: '+output_file)
        if not(os.path.isfile(obs_file)):
            print('File with observations not found: '+obs_file)

# def plot_output_discrepancy(output_type,output_file,obs_file,out_dir,out_info):

# if(os.path.isfile(output_file) and os.path.isfile(obs_file)):
# tot_outputs = np.loadtxt(output_file).shape[1]
# print('Plotting posterior predictive samples...')
# for loopA in range(tot_outputs):
# for loopB in range(loopA+1, tot_outputs):
# plot_outputs_discr(output_type,output_file,obs_file,loopA,loopB,out_dir,out_info,fig_format=args.img_format,use_dark_mode=args.use_dark_mode)
# else:
# if not(os.path.isfile(output_file)):
# print('File with posterior predictive samples ('+output_type+') not found: '+output_file)
# if not(os.path.isfile(obs_file)):
# print('File with observations not found: '+obs_file)

def plot_log(log_file,out_dir,fig_format='png',use_dark_mode=False):
log_data = np.loadtxt(log_file)

Expand Down Expand Up @@ -86,7 +114,50 @@ def plot_outputs(sample_file,obs_file,idx1,idx2,out_dir,out_info,fig_format='png
plt.xlim([avg_1-3*std_1,avg_1+3*std_1])
plt.ylim([avg_2-3*std_2,avg_2+3*std_2])
plt.tight_layout()
plt.savefig(out_dir+'data_plot_' + out_info + '_'+str(idx1)+'_'+str(idx2)+'.'+fig_format,bbox_inches='tight',dpi=200)
# Save plot
plt.savefig(out_dir + out_info + '_'+str(idx1)+'_'+str(idx2)+'.'+fig_format,bbox_inches='tight',dpi=200)
plt.close()

def plot_outputs_discr(out_type,sample_file,obs_file,idx1,idx2,out_dir,out_info,fig_format='png',use_dark_mode=False):
    """Scatter-plot two output components of the posterior predictive
    samples against the corresponding observations and save the figure.

    Args:
        out_type: which predictive quantity is plotted — one of 'none',
            'disc', 'lf', 'lf+disc', 'lf+disc+noise'; selects the output
            file-name prefix.
        sample_file: text file of samples (rows = samples, columns = outputs).
        obs_file: text file of observations; NOTE(review): indexed as
            obs_data[idx,:], i.e. transposed w.r.t. the samples — confirm
            the layout against the code that writes it.
        idx1, idx2: column indices of the two outputs to plot.
        out_dir: output directory (expected to end with a path separator).
        out_info: string embedded in the output file name.
        fig_format: image format passed to savefig.
        use_dark_mode: if True, use matplotlib's dark background style.
    """
    # Read data
    sample_data = np.loadtxt(sample_file)
    obs_data = np.loadtxt(obs_file)

    # Set dark mode
    if(use_dark_mode):
        plt.style.use('dark_background')

    plt.figure(figsize=(2.5,2))
    # Samples: one column per output component
    plt.scatter(sample_data[:,idx1],sample_data[:,idx2],s=2,c='b',marker='o',edgecolor=None,alpha=0.1)
    # Observations: one row per output component
    plt.scatter(obs_data[idx1,:],obs_data[idx2,:],s=3,c='r',alpha=1,zorder=99)
    plt.gca().xaxis.set_major_formatter(mtick.FormatStrFormatter('%.2f'))
    plt.gca().yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2f'))
    # fs is assumed to be a module-level font-size constant — defined at file scope
    plt.gca().tick_params(axis='both', labelsize=fs)
    plt.xlabel('$x_{'+str(idx1+1)+'}$',fontsize=fs)
    plt.ylabel('$x_{'+str(idx2+1)+'}$',fontsize=fs)
    # Set limits based on avg and std
    avg_1 = np.mean(sample_data[:,idx1])
    std_1 = np.std(sample_data[:,idx1])
    avg_2 = np.mean(sample_data[:,idx2])
    std_2 = np.std(sample_data[:,idx2])
    plt.xlim([avg_1-3*std_1,avg_1+3*std_1])
    plt.ylim([avg_2-3*std_2,avg_2+3*std_2])
    plt.tight_layout()
    # Map output type to file prefix. The original if/elif chain left
    # file_prefix undefined (NameError at savefig) for any unrecognized
    # out_type; a dict lookup with a default is both safe and shorter.
    prefix_by_type = {'none':          'data_plot_',
                      'disc':          'data_plot_disc_',
                      'lf':            'data_plot_lf_',
                      'lf+disc':       'data_plot_lf+disc_',
                      'lf+disc+noise': 'data_plot_lf+disc+noise_'}
    file_prefix = prefix_by_type.get(out_type, 'data_plot_')
    # Save plot
    plt.savefig(out_dir + file_prefix + out_info + '_'+str(idx1) + '_' + str(idx2) + '.' + fig_format,bbox_inches='tight',dpi=200)
    plt.close()

# =========
Expand All @@ -109,7 +180,6 @@ def plot_outputs(sample_file,obs_file,idx1,idx2,out_dir,out_info,fig_format='png
metavar='',
dest='folder_name')


# folder name
parser.add_argument('-n', '--name',
action=None,
Expand Down Expand Up @@ -164,9 +234,22 @@ def plot_outputs(sample_file,obs_file,idx1,idx2,out_dir,out_info,fig_format='png
sample_file = out_dir + args.exp_name + '_samples_' + str(args.step_num)
param_file = out_dir + args.exp_name + '_params_' + str(args.step_num)
LL_file = out_dir + args.exp_name + '_logdensity_' + str(args.step_num)
output_file = out_dir + args.exp_name + '_outputs_' + str(args.step_num)
obs_file = out_dir + args.exp_name + '_data'
out_info = args.exp_name + '_' + str(args.step_num)

# Output files
output_file = out_dir + args.exp_name + '_outputs_' + str(args.step_num)
output_file_lf = out_dir + args.exp_name + '_outputs_lf_' + str(args.step_num)
output_file_lf_disc = out_dir + args.exp_name + '_outputs_lf+discr_' + str(args.step_num)
output_file_lf_disc_noise = out_dir + args.exp_name + '_outputs_lf+discr+noise_' + str(args.step_num)

# Observation file
obs_file = out_dir + args.exp_name + '_data'
out_info = args.exp_name + '_' + str(args.step_num)

# Check if this is a case with discrepancy or not
if(os.path.isfile(output_file_lf_disc)):
is_discrepancy = True
else:
is_discrepancy = False

# Plot loss profile
if(os.path.isfile(log_file)):
Expand All @@ -175,7 +258,7 @@ def plot_outputs(sample_file,obs_file,idx1,idx2,out_dir,out_info,fig_format='png
else:
print('Log file not found: '+log_file)

# Plot 2D slice of posterior samples
# Plot posterior samples
if(os.path.isfile(param_file) and os.path.isfile(LL_file)):
tot_params = np.loadtxt(param_file).shape[1]
print('Plotting posterior samples...')
Expand All @@ -186,13 +269,6 @@ def plot_outputs(sample_file,obs_file,idx1,idx2,out_dir,out_info,fig_format='png
print('File with posterior samples not found: '+param_file)
print('File with log-density not found: '+LL_file)

# Plot 2D slice of outputs and observations
if(os.path.isfile(output_file) and os.path.isfile(obs_file)):
tot_outputs = np.loadtxt(output_file).shape[1]
print('Plotting posterior predictive samples...')
for loopA in range(tot_outputs):
for loopB in range(loopA+1, tot_outputs):
plot_outputs(output_file,obs_file,loopA,loopB,out_dir,out_info,fig_format=args.img_format,use_dark_mode=args.use_dark_mode)
else:
print('File with posterior predictive samples not found: '+output_file)
print('File with observations not found: '+obs_file)
# Plot posterior predictive distribution with observations
plot_output(output_file,obs_file,out_dir,out_info)

3 changes: 2 additions & 1 deletion linfa/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from linfa.tests.test_linfa import linfa_test_suite
from linfa.tests.test_linfa import linfa_test_suite
from linfa.tests.test_linfa_discrepancy import linfa_test_suite_discrepancy
Loading

0 comments on commit 893e986

Please sign in to comment.