Skip to content

Commit

Permalink
Add configuration option seed to switch between HASH based and REPLIC…
Browse files Browse the repository at this point in the history
…ATE based #94
  • Loading branch information
gaow committed May 22, 2020
1 parent 8f01b9a commit 8a3768d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
17 changes: 11 additions & 6 deletions src/dsc_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,8 @@ def __init__(self,
self.exe = None
# script plugin object
self.plugin = None
# script seed option
self.seed = None
# runtime variables
self.workdir = None
self.libpath = None
Expand Down Expand Up @@ -923,15 +925,18 @@ def set_options(self, common_option, spec_option):
libpath2 = try_get_value(spec_option, 'lib_path')
path1 = try_get_value(common_option, 'exec_path')
path2 = try_get_value(spec_option, 'exec_path')
container1 = try_get_value(common_option, 'container', [None])[0]
container2 = try_get_value(spec_option, 'container', [None])[0]
container_engine1 = try_get_value(common_option, 'container_engine', [None])[0]
container_engine2 = try_get_value(spec_option, 'container_engine', [None])[0]
container1 = try_get_value(common_option, 'container')
container2 = try_get_value(spec_option, 'container')
container_engine1 = try_get_value(common_option, 'container_engine')
container_engine2 = try_get_value(spec_option, 'container_engine')
seed1 = try_get_value(common_option, 'seed', 'HASH')
seed2 = try_get_value(spec_option, 'seed')
self.workdir = workdir2 if workdir2 is not None else workdir1
self.libpath = libpath2 if libpath2 is not None else libpath1
self.path = path2 if path2 is not None else path1
self.container = container2 if container2 is not None else container1
self.container_engine = container_engine2 if container_engine2 is not None else container_engine1
self.seed = seed2[0] if seed2 is not None else seed1
self.container = container2[0] if container2 is not None else container1
self.container_engine = container_engine2[0] if container_engine2 is not None else container_engine1
self.rlib = try_get_value(spec_option, 'R_libs', [])
self.pymodule = try_get_value(spec_option, 'python_modules', [])
if not self.container is None and (self.rlib or self.pymodule):
Expand Down
3 changes: 2 additions & 1 deletion src/dsc_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,8 @@ def get_action(self):
and len(self.step.rv))
script_begin += '\n' + plugin.get_input(
self.params,
self.step.libpath if self.step.libpath else [])
self.step.libpath if self.step.libpath else [],
self.step.seed)
if len(self.step.rf):
script_begin += '\n' + plugin.get_output(
self.step.rf)
Expand Down
21 changes: 15 additions & 6 deletions src/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(self, identifier=''):
super().__init__(name='bash', identifier=identifier)
self.output_ext = 'yml'

def get_input(self, params, lib):
def get_input(self, params, lib, seed_option):
res = 'rm -f $[_output]\n'
if len(lib):
res += '\n'.join([
Expand All @@ -138,7 +138,10 @@ def get_input(self, params, lib):
self.get_var(k), k)
# FIXME: may need a timer
# seed
res += '\nRANDOM=$(($DSC_REPLICATE))'
if 'seed_option' == 'REPLICATE':
res += '\nRANDOM=$(($DSC_REPLICATE))'
else:
res += '\nRANDOM=$(($DSC_REPLICATE + $[DSC_STEP_ID_]))'
return res

def get_output(self, params):
Expand Down Expand Up @@ -344,7 +347,7 @@ def load_env(self, depends_other, depends_self):
res += '\n' + '\n'.join(sorted(self.tempfile))
return res

def get_input(self, params, lib):
def get_input(self, params, lib, seed_option):
res = 'dscrutils:::source_dirs(c({}))\n'.format(','.join(
[repr(x) for x in lib])) if len(lib) else ''
# load parameters
Expand All @@ -355,7 +358,10 @@ def get_input(self, params, lib):
# timer
res += f'\nTIC_{self.identifier[4:]} <- proc.time()'
# seed
res += '\nset.seed(DSC_REPLICATE)'
if seed_option == 'REPLICATE':
res += '\nset.seed(DSC_REPLICATE)'
else:
res += '\nset.seed(DSC_REPLICATE + ${DSC_STEP_ID_})'
return res

def get_output(self, params):
Expand Down Expand Up @@ -526,7 +532,7 @@ def load_env(self, depends_other, depends_self):
res += '\n' + '\n'.join(sorted(self.tempfile))
return res

def get_input(self, params, lib):
def get_input(self, params, lib, seed_option):
res = '\n'.join(
[f'sys.path.append(os.path.expanduser("{item}"))' for item in lib])
# load parameters
Expand All @@ -535,7 +541,10 @@ def get_input(self, params, lib):
for k in keys:
res += '\n%s = ${_%s}' % (self.get_var(k), k)
res += f'\nTIC_{self.identifier[4:]} = timeit.default_timer()'
res += '\nimport random\nrandom.seed(DSC_REPLICATE)\ntry:\n\timport numpy; numpy.random.seed(DSC_REPLICATE)\nexcept Exception:\n\tpass'
if seed_option == 'REPLICATE':
res += '\nimport random\nrandom.seed(DSC_REPLICATE)\ntry:\n\timport numpy; numpy.random.seed(DSC_REPLICATE)\nexcept Exception:\n\tpass'
else:
res += '\nimport random\nrandom.seed(DSC_REPLICATE + ${DSC_STEP_ID_})\ntry:\n\timport numpy; numpy.random.seed(DSC_REPLICATE + ${DSC_STEP_ID_})\nexcept Exception:\n\tpass'
return res

def get_output(self, params):
Expand Down

0 comments on commit 8a3768d

Please sign in to comment.