Skip to content

Commit

Permalink
Merge pull request #72 from zwicker-group/templates
Browse files Browse the repository at this point in the history
Improved job script tempaltes
  • Loading branch information
david-zwicker authored Jun 6, 2024
2 parents f56a312 + 45002c8 commit faf1eb8
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 5 deletions.
9 changes: 6 additions & 3 deletions modelrunner/run/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,13 @@ def submit_job(
job_args = []
if parameters is not None and len(parameters) > 0:
if isinstance(parameters, dict):
parameters = json.dumps(parameters)
elif not isinstance(parameters, str):
parameters_json = json.dumps(parameters)
elif isinstance(parameters, str):
parameters_json = parameters
else:
raise TypeError("Parameters need to be given as a string or a dict")
job_args.append(f"--json {escape_string(parameters)}")
job_args.append(f"--json {escape_string(parameters_json)}")
script_args["PARAMETERS"] = parameters # allow using parameters in job script

logger.debug("Job arguments: `%s`", job_args)

Expand Down
16 changes: 16 additions & 0 deletions scripts/run_tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

from __future__ import annotations

import argparse
import os
import subprocess as sp
Expand Down Expand Up @@ -66,6 +68,7 @@ def run_unit_tests(
coverage: bool = False,
no_numba: bool = False,
pattern: str = None,
pytest_args: list[str] = [],
) -> int:
"""run the unit tests
Expand All @@ -74,6 +77,9 @@ def run_unit_tests(
coverage (bool): Whether to determine the test coverage
no_numba (bool): Whether to disable numba jit compilation
pattern (str): A pattern that determines which tests are ran
pytest_args (list of str):
Additional arguments forwarded to py.test. For instance ["--maxfail=1"]
fails tests early.
Returns:
int: The return code indicating success or failure
Expand Down Expand Up @@ -117,6 +123,8 @@ def run_unit_tests(
]
)

args.extend(pytest_args)

# specify the package to run
args.append("tests")

Expand Down Expand Up @@ -191,6 +199,13 @@ def main():
help="Write a report of the results",
)

# set py.test arguments
group = parser.add_argument_group(
"py.test arguments",
description="Additional arguments separated by `--` are forward to py.test",
)
group.add_argument("pytest_args", nargs="*", help=argparse.SUPPRESS)

# parse the command line arguments
args = parser.parse_args()
run_all = not (args.style or args.types or args.unit)
Expand All @@ -206,6 +221,7 @@ def main():
parallel=args.parallel,
no_numba=args.no_numba,
pattern=args.pattern,
pytest_args=args.pytest_args,
)


Expand Down
16 changes: 16 additions & 0 deletions scripts/tests_debug.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash

export PYTHONPATH=../py-modelrunner # likely path of package, relative to current base path

if [ ! -z $1 ]
then
# test pattern was specified
echo 'Run unittests with pattern '$1':'
./run_tests.py --unit --pattern "$1" -- \
-o log_cli=true --log-cli-level=debug -vv
else
# test pattern was not specified
echo 'Run all unittests:'
./run_tests.py --unit -- \
-o log_cli=true --log-cli-level=debug -vv
fi
19 changes: 19 additions & 0 deletions tests/run/scripts/custom.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash -l

export PYTHONPATH={{ PACKAGE_PATH }}:$PYTHONPATH

{% if CONFIG.num_threads is number %}
# set the number of threads to use
export MKL_NUM_THREADS={{ CONFIG.num_threads }}
export NUMBA_NUM_THREADS={{ CONFIG.num_threads }}
export NUMEXPR_NUM_THREADS={{ CONFIG.num_threads }}
export OMP_NUM_THREADS={{ CONFIG.num_threads }}
export OPENBLAS_NUM_THREADS={{ CONFIG.num_threads }}
{% endif %}

{% if OUTPUT_FOLDER is defined and OUTPUT_FOLDER %}
mkdir -p {{ OUTPUT_FOLDER }}
{% endif %}

# Run the program
{{ CONFIG.python_bin }} -m modelrunner {{ MODEL_FILE }} --a {{ PARAMETERS.a }}
18 changes: 16 additions & 2 deletions tests/run/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def test_submit_job_stdout(tmp_path, method):
overwrite_strategy="silent_overwrite",
)

assert outs == "3.0\n"
assert errs == ""
assert outs == "3.0\n"
assert Result.from_file(output).result is None


Expand All @@ -69,8 +69,8 @@ def test_submit_job_no_output():
method="foreground",
overwrite_strategy="silent_overwrite",
)
assert outs == "3.0\n"
assert errs == ""
assert outs == "3.0\n"


def test_submit_jobs(tmp_path):
Expand Down Expand Up @@ -137,3 +137,17 @@ def run(**p):

assert run() == ("", "")
assert run(a=1) == ('--json{"a": 1}', "")


def test_submit_job_own_template(tmp_path):
"""test the submit_job function with a custom template"""
outs, errs = submit_job(
SCRIPT_PATH / "print.py",
method="foreground",
parameters={"a": 5, "b": 10}, # b is not used by template
template=SCRIPT_PATH / "custom.jinja",
overwrite_strategy="silent_overwrite",
)

assert errs == ""
assert outs == "7.0\n" # a + 2 (default value of b)

0 comments on commit faf1eb8

Please sign in to comment.