Skip to content

Commit

Permalink
added marginal stats plots
Browse files Browse the repository at this point in the history
  • Loading branch information
kylajones committed Dec 20, 2023
1 parent f0ec0cf commit 0fd62d9
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
52 changes: 51 additions & 1 deletion linfa/plot_disc.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,41 @@ def eval_discrepancy_custom_grid(file_path,train_grid_in,train_grid_out,test_gri
# Evaluate surrogate
return dicr.forward(test_grid)

def plot_marginal_stats(marg_stats_file, step_num, saveinterval, out_dir):
iterations = np.arange(start=saveinterval, stop=step_num + saveinterval, step=saveinterval, dtype=int)

fig, axes = plt.subplots(2, 2, figsize=(10, 8))
axes = axes.flatten()

for loopA in range(len(iterations)):
# Read file
stats = np.loadtxt(marg_stats_file + str(iterations[loopA]))
mean = stats[:, 0]
sd = stats[:, 1]

# Plot mean and sd of calibration parameters with adjusted data point label font size
axes[0].plot(iterations[loopA], mean[0], 'ko', markersize=8)
axes[1].plot(iterations[loopA], mean[1], 'kv', markersize=8)
axes[2].plot(iterations[loopA], sd[0], 'bo', markersize=8)
axes[3].plot(iterations[loopA], sd[1], 'bv', markersize=8)

# Set common labels
for ax, ylabel in zip(axes, [r'$\bar{\theta_1}$', r'$\bar{\theta_2}$', r'$se({\theta_1})$', r'$se({\theta_2})$']):
ax.set_ylabel(ylabel, fontsize=15, fontweight='bold')

# Set x-axis label only for the bottom two subplots
for ax in axes[-2:]:
ax.set_xlabel('Iterations', fontsize=15, fontweight='bold')

# Set tick label font size
for ax in axes:
ax.tick_params(axis='both', labelsize=15)

# Adjust layout and save the figure
plt.tight_layout()
plt.savefig(out_dir + 'marginal_stats', bbox_inches = 'tight', dpi = 300)
plt.show()

# =========
# MAIN CODE
# =========
Expand Down Expand Up @@ -306,7 +341,7 @@ def eval_discrepancy_custom_grid(file_path,train_grid_in,train_grid_out,test_gri
const=None,
default='histograms',
type=str,
choices=['histograms','discr_surface'],
choices=['histograms','discr_surface', 'marginal_stats'],
required=False,
help='Type of plot/result to generate',
metavar='',
Expand Down Expand Up @@ -335,6 +370,17 @@ def eval_discrepancy_custom_grid(file_path,train_grid_in,train_grid_out,test_gri
help='Factor for test grid limits from data file',
metavar='',
dest='data_limit_factor')

# save interval
parser.add_argument('-si', '--saveinterval',
action=None,
# nargs='+',
const=None,
default=1.0,
type=float,
required=False,
help='Save interval to read for each iteration',
metavar='',)

# Parse Commandline Arguments
args = parser.parse_args()
Expand All @@ -346,13 +392,17 @@ def eval_discrepancy_custom_grid(file_path,train_grid_in,train_grid_out,test_gri
lf_discr_noise_file = out_dir + args.exp_name + '_outputs_lf+discr+noise_' + str(args.step_num)
discr_sur_file = out_dir + args.exp_name
data_file = out_dir + args.exp_name + '_data'
marg_stats_file = out_dir + args.exp_name + '_marginal_stats_'

# out_info = args.exp_name + '_' + str(args.step_num)

if(args.result_mode == 'histograms'):
plot_disr_histograms(lf_file, lf_dicr_file, lf_discr_noise_file, data_file, out_dir)
elif(args.result_mode == 'discr_surface'):
plot_discr_surface_2d(discr_sur_file, lf_dicr_file, data_file, args.num_1d_grid_points, args.data_limit_factor, out_dir)
elif(args.result_mode == 'marginal_stats'):
plot_marginal_stats(marg_stats_file, args.step_num, args.saveinterval, out_dir)

else:
print('ERROR. Invalid execution mode')
exit(-1)
Expand Down
12 changes: 10 additions & 2 deletions linfa/tests/test_plot_discr.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
# --num_points 10 \
# --limfactor 1.0 \

# python3 -m linfa.plot_disc --folder results/ \
# --name test_18_lf_w_disc_TP15_prior_rep_meas \
# --iter 25000 \
# --mode discr_surface \
# --num_points 40 \
# --limfactor 1.0 \

python3 -m linfa.plot_disc --folder results/ \
--name test_18_lf_w_disc_TP15_prior_rep_meas \
--name test_06_lf_w_disc_TP1 \
--iter 25000 \
--mode discr_surface \
--mode marginal_stats \
--saveinterval 1000 \
--num_points 40 \
--limfactor 1.0 \

Expand Down

0 comments on commit 0fd62d9

Please sign in to comment.