# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Plot results for different side effects penalties.

Loads csv result files generated by `run_experiment` and outputs a summary data
frame in a csv file, to be used for plotting by plot_results.ipynb.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path

from absl import app
from absl import flags
import pandas as pd

from side_effects_penalties.file_loading import load_files


FLAGS = flags.FLAGS

if __name__ == '__main__':  # Avoid defining flags when used as a library.
  flags.DEFINE_string('path', '', 'File path.')
  flags.DEFINE_string('input_suffix', '',
                      'Filename suffix to use when loading data files.')
  flags.DEFINE_string('output_suffix', '',
                      'Filename suffix to use when saving files.')
  flags.DEFINE_bool('bar_plot', True,
                    'Make a data frame for a bar plot (True) ' +
                    'or learning curves (False).')
  flags.DEFINE_string('env_name', 'box', 'Environment name.')
  flags.DEFINE_bool('noops', True, 'Whether the environment includes noops.')
  flags.DEFINE_list('beta_list', [0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0],
                    'List of beta values.')
  flags.DEFINE_list('seed_list', [1], 'List of random seeds.')
  flags.DEFINE_bool('compare_penalties', True,
                    'Compare different penalties using the best beta value ' +
                    'for each penalty (True), or compare different beta ' +
                    'values for the same penalty (False).')
  flags.DEFINE_enum('dev_measure', 'rel_reach',
                    ['none', 'reach', 'rel_reach', 'att_util'],
                    'Deviation measure (used if compare_penalties=False).')
  flags.DEFINE_enum('dev_fun', 'truncation', ['truncation', 'absolute'],
                    'Summary function for the deviation measure ' +
                    '(used if compare_penalties=False).')
  flags.DEFINE_float('value_discount', 0.99,
                     'Discount factor for deviation measure value function ' +
                     '(used if compare_penalties=False).')
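
# Example invocation (a sketch; the results path below is hypothetical, all
# flags are defined above):
#   python -m side_effects_penalties.results_summary \
#       --path=/tmp/results --env_name=box --compare_penalties=True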


def beta_choice(baseline, dev_measure, dev_fun, value_discount, env_name,
                beta_list, seed_list, noops=False, path='', suffix=''):
  """Choose the beta value that gives the highest final performance."""
  if dev_measure == 'none':
    # Without a deviation measure there is no penalty, so beta is irrelevant.
    return 0.1
  perf_max = float('-inf')
  best_beta = 0.0
  for beta in beta_list:
    df = load_files(baseline=baseline, dev_measure=dev_measure,
                    dev_fun=dev_fun, value_discount=value_discount, beta=beta,
                    env_name=env_name, noops=noops, path=path, suffix=suffix,
                    seed_list=seed_list)
    if df.empty:
      perf = float('-inf')
    else:
      perf = df['performance_smooth'].mean()
    if perf > perf_max:
      perf_max = perf
      best_beta = beta
  return best_beta
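
# A hypothetical call (values for illustration only): pick the best beta for
# the relative reachability penalty with a stepwise baseline:
#   best = beta_choice(baseline='stepwise', dev_measure='rel_reach',
#                      dev_fun='truncation', value_discount=0.99,
#                      env_name='box', beta_list=[0.1, 1.0, 10.0],
#                      seed_list=[1], noops=True, path='/tmp/results')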


def penalty_label(dev_measure, dev_fun, value_discount):
  """Penalty label specifying design choices."""
  dev_measure_labels = {
      'none': 'None', 'rel_reach': 'RR', 'att_util': 'AU', 'reach': 'UR'}
  label = dev_measure_labels[dev_measure]
  disc_lab = 'u' if value_discount == 1.0 else 'd'
  dev_lab = ''
  if dev_measure in ['rel_reach', 'att_util']:
    dev_lab = 't' if dev_fun == 'truncation' else 'a'
  if dev_measure != 'none':
    label = label + '(' + disc_lab + dev_lab + ')'
  return label
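
# For illustration, a few labels this produces:
#   penalty_label('rel_reach', 'truncation', 0.99)  -> 'RR(dt)'
#   penalty_label('att_util', 'absolute', 1.0)      -> 'AU(ua)'
#   penalty_label('reach', 'none', 0.99)            -> 'UR(d)'
#   penalty_label('none', 'none', 0.99)             -> 'None'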


def make_summary_data_frame(
    env_name, beta_list, seed_list, final=True, baseline=None,
    dev_measure=None, dev_fun=None, value_discount=None, noops=False,
    compare_penalties=True, path='', input_suffix='', output_suffix=''):
  """Make summary dataframe from multiple csv result files and output to csv."""
  # For each of the penalty parameters (baseline, dev_measure, dev_fun, and
  # value_discount), compare a list of multiple values if the parameter is
  # None, or use the provided parameter value if it is not None.
  baseline_list = ['start', 'inaction', 'stepwise', 'step_noroll']
  if dev_measure is not None:
    dev_measure_list = [dev_measure]
  else:
    dev_measure_list = ['none', 'reach', 'rel_reach', 'att_util']
  dataframes = []
  for dev_measure in dev_measure_list:
    # These deviation measures don't have a deviation function:
    if dev_measure in ['reach', 'none']:
      dev_fun_list = ['none']
    elif dev_fun is not None:
      dev_fun_list = [dev_fun]
    else:
      dev_fun_list = ['truncation', 'absolute']
    # These deviation measures must be discounted:
    if dev_measure in ['none', 'att_util']:
      value_discount_list = [0.99]
    elif value_discount is not None:
      value_discount_list = [value_discount]
    else:
      value_discount_list = [0.99, 1.0]
    for baseline in baseline_list:
      for vd in value_discount_list:
        for devf in dev_fun_list:
          # Choose the best beta for this set of penalty parameters if
          # compare_penalties=True, or compare all betas otherwise.
          if compare_penalties:
            beta = beta_choice(
                baseline=baseline, dev_measure=dev_measure, dev_fun=devf,
                value_discount=vd, env_name=env_name, noops=noops,
                beta_list=beta_list, seed_list=seed_list, path=path,
                suffix=input_suffix)
            betas = [beta]
          else:
            betas = beta_list
          for beta in betas:
            label = penalty_label(
                dev_measure=dev_measure, dev_fun=devf, value_discount=vd)
            df_part = load_files(
                baseline=baseline, dev_measure=dev_measure, dev_fun=devf,
                value_discount=vd, beta=beta, env_name=env_name,
                noops=noops, path=path, suffix=input_suffix, final=final,
                seed_list=seed_list)
            df_part = df_part.assign(
                baseline=baseline, dev_measure=dev_measure, dev_fun=devf,
                value_discount=vd, beta=beta, env_name=env_name, label=label)
            dataframes.append(df_part)
  df = pd.concat(dataframes, sort=False)
  # Output the summary data frame.
  final_str = '_final' if final else ''
  if compare_penalties:
    filename = ('df_summary_penalties_' + env_name + final_str +
                output_suffix + '.csv')
  else:
    filename = ('df_summary_betas_' + env_name + '_' + dev_measure + '_' +
                dev_fun + '_' + str(value_discount) + final_str + output_suffix
                + '.csv')
  f = os.path.join(path, filename)
  df.to_csv(f)
  return df
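
# For example, with compare_penalties=True, env_name='box', final=True and an
# empty output_suffix, the summary is written to
# 'df_summary_penalties_box_final.csv' under `path`.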


def main(unused_argv):
  compare_penalties = FLAGS.compare_penalties
  dev_measure = None if compare_penalties else FLAGS.dev_measure
  dev_fun = None if compare_penalties else FLAGS.dev_fun
  value_discount = None if compare_penalties else FLAGS.value_discount
  make_summary_data_frame(
      compare_penalties=compare_penalties, env_name=FLAGS.env_name,
      noops=FLAGS.noops, final=FLAGS.bar_plot, dev_measure=dev_measure,
      value_discount=value_discount, dev_fun=dev_fun, path=FLAGS.path,
      input_suffix=FLAGS.input_suffix, output_suffix=FLAGS.output_suffix,
      beta_list=FLAGS.beta_list, seed_list=FLAGS.seed_list)


if __name__ == '__main__':
  app.run(main)