diff --git a/mlos_bench/mlos_bench/config/schemas/cli/cli-schema.json b/mlos_bench/mlos_bench/config/schemas/cli/cli-schema.json index cd57169a35..0d11769f22 100644 --- a/mlos_bench/mlos_bench/config/schemas/cli/cli-schema.json +++ b/mlos_bench/mlos_bench/config/schemas/cli/cli-schema.json @@ -79,6 +79,13 @@ "examples": [3, 5] }, + "num_trial_runners": { + "description": "Number of trial runner instances to use to execute benchmark environments. Individual TrialRunners can be identified in configs with $trial_runner_id and optionally run in parallel.", + "type": "integer", + "minimum": 1, + "examples": [1, 3, 5, 10] + }, + "storage": { "description": "Path to the json config describing the storage backend to use.", "$ref": "#/$defs/json_config_path" diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index ba346b6765..60c267eb1b 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -43,6 +43,15 @@ class Environment(metaclass=abc.ABCMeta): # pylint: disable=too-many-instance-attributes """An abstract base of all benchmark environments.""" + # Should be provided by the runtime. + _COMMON_CONST_ARGS = { + "trial_runner_id", + } + _COMMON_REQ_ARGS = { + "experiment_id", + "trial_id", + } + @classmethod def new( # pylint: disable=too-many-arguments cls, @@ -123,6 +132,12 @@ def __init__( # pylint: disable=too-many-arguments An optional service object (e.g., providing methods to deploy or reboot a VM/Host, etc.). """ + global_config = global_config or {} + # Make some usual runtime arguments available for tests. + for arg in self._COMMON_CONST_ARGS: + global_config.setdefault(arg, None) + for arg in self._COMMON_REQ_ARGS: + global_config.setdefault(arg, None) self._validate_json_config(config, name) self.name = name self.config = config @@ -161,8 +176,9 @@ def __init__( # pylint: disable=too-many-arguments req_args = set(config.get("required_args", [])) - set( self._tunable_params.get_param_values().keys() ) + req_args.update(self._COMMON_CONST_ARGS) merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args) - self._const_args = self._expand_vars(self._const_args, global_config or {}) + self._const_args = self._expand_vars(self._const_args, global_config) self._params = self._combine_tunables(self._tunable_params) _LOG.debug("Parameters for '%s' :: %s", name, self._params) @@ -332,6 +348,18 @@ def tunable_params(self) -> TunableGroups: """ return self._tunable_params + @property + def const_args(self) -> Dict[str, TunableValue]: + """ + Get the constant arguments for this Environment. + + Returns + ------- + parameters : Dict[str, TunableValue] + Key/value pairs of all environment const_args parameters. + """ + return self._const_args.copy() + @property def parameters(self) -> Dict[str, TunableValue]: """ @@ -345,7 +373,7 @@ def parameters(self) -> Dict[str, TunableValue]: Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`). 
""" - return self._params + return self._params.copy() def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index 339a11963d..e0e9275731 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -22,6 +22,7 @@ from mlos_bench.optimizers.mock_optimizer import MockOptimizer from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer from mlos_bench.schedulers.base_scheduler import Scheduler +from mlos_bench.schedulers.trial_runner import TrialRunner from mlos_bench.services.base_service import Service from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.services.local.local_exec import LocalExecService @@ -44,6 +45,7 @@ class Launcher: def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None): # pylint: disable=too-many-statements + # pylint: disable=too-complex # pylint: disable=too-many-locals _LOG.info("Launch: %s", description) epilog = """ @@ -108,6 +110,7 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st args_rest=args_rest, global_config=cli_config_args, ) + # TODO: Can we generalize these two rules using excluded_cli_args? # experiment_id is generally taken from --globals files, but we also allow # overriding it on the CLI. # It's useful to keep it there explicitly mostly for the --help output. @@ -117,6 +120,13 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st # set it via command line if args.trial_config_repeat_count: self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count + self.global_config.setdefault("num_trial_runners", 1) + if args.num_trial_runners: + self.global_config["num_trial_runners"] = args.num_trial_runners + if self.global_config["num_trial_runners"] <= 0: + raise ValueError( + f"Invalid num_trial_runners: {self.global_config['num_trial_runners']}" + ) # Ensure that the trial_id is present since it gets used by some other # configs but is typically controlled by the run optimize loop. self.global_config.setdefault("trial_id", 1) @@ -142,13 +152,28 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st ) self.root_env_config = self._config_loader.resolve_path(env_path) - self.environment: Environment = self._config_loader.load_environment( - self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service + self.trial_runners: List[TrialRunner] = [] + for trial_runner_id in range(self.global_config["num_trial_runners"]): + # Create a new global config for each Environment with a unique trial_runner_id for it. + env_global_config = self.global_config.copy() + env_global_config["trial_runner_id"] = trial_runner_id + env = self._config_loader.load_environment( + self.root_env_config, + TunableGroups(), + env_global_config, + service=self._parent_service, + ) + self.trial_runners.append(TrialRunner(trial_runner_id, env)) + _LOG.info( + "Init %d trial runners for environments: %s", + len(self.trial_runners), + list(trial_runner.environment for trial_runner in self.trial_runners), ) - _LOG.info("Init environment: %s", self.environment) - # NOTE: Init tunable values *after* the Environment, but *before* the Optimizer + # NOTE: Init tunable values *after* the Environment(s), but *before* the Optimizer + # TODO: should we assign the same or different tunables for all TrialRunner Environments? 
self.tunables = self._init_tunable_values( + self.trial_runners[0].environment, args.random_init or config.get("random_init", False), config.get("random_seed") if args.random_seed is None else args.random_seed, config.get("tunable_values", []) + (args.tunable_values or []), @@ -278,6 +303,18 @@ def add_argument(self, *args: Any, **kwargs: Any) -> None: ), ) + parser.add_argument( + "--num_trial_runners", + "--num-trial-runners", + required=False, + type=int, + help=( + "Number of TrialRunners to use for executing benchmark Environments. " + "Individual TrialRunners can be identified in configs with $trial_runner_id " + "and optionally run in parallel." + ), + ) + path_args_tracker.add_argument( "--scheduler", required=False, @@ -428,6 +465,7 @@ def _load_config( def _init_tunable_values( self, + env: Environment, random_init: bool, seed: Optional[int], args_tunables: Optional[str], @@ -435,7 +473,7 @@ def _init_tunable_values( """Initialize the tunables and load key/value pairs of the tunable values from given JSON files, if specified. """ - tunables = self.environment.tunable_params + tunables = env.tunable_params _LOG.debug("Init tunables: default = %s", tunables) if random_init: @@ -534,7 +572,7 @@ def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: "teardown": self.teardown, }, global_config=self.global_config, - environment=self.environment, + trial_runners=self.trial_runners, optimizer=self.optimizer, storage=self.storage, root_env_config=self.root_env_config, @@ -544,7 +582,7 @@ def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: return self._config_loader.build_scheduler( config=class_config, global_config=self.global_config, - environment=self.environment, + trial_runners=self.trial_runners, optimizer=self.optimizer, storage=self.storage, root_env_config=self.root_env_config, diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py index f38e51e713..b312f07480 100644 --- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py @@ -9,7 +9,7 @@ from abc import ABCMeta, abstractmethod from datetime import datetime from types import TracebackType -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type from pytz import UTC from typing_extensions import Literal @@ -17,6 +17,7 @@ from mlos_bench.config.schemas import ConfigSchema from mlos_bench.environments.base_environment import Environment from mlos_bench.optimizers.base_optimizer import Optimizer +from mlos_bench.schedulers.trial_runner import TrialRunner from mlos_bench.storage.base_storage import Storage from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.util import merge_parameters @@ -33,7 +34,7 @@ def __init__( # pylint: disable=too-many-arguments *, config: Dict[str, Any], global_config: Dict[str, Any], - environment: Environment, + trial_runners: List[TrialRunner], optimizer: Optimizer, storage: Storage, root_env_config: str, @@ -41,23 +42,23 @@ def __init__( # pylint: disable=too-many-arguments """ Create a new instance of the scheduler. The constructor of this and the derived classes is called by the persistence service after reading the class JSON - configuration. Other objects like the Environment and Optimizer are provided by - the Launcher. + configuration. Other objects like the TrialRunner(s) and their Environment(s) + and Optimizer are provided by the Launcher. 
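+
+        A rough usage sketch (illustrative only -- normally the Launcher and the
+        config persistence service drive this) might look like:
+
+        ```python
+        scheduler = SyncScheduler(
+            config={"experiment_id": "MyExperiment", "trial_id": 1, "max_trials": 10},
+            global_config={},
+            trial_runners=trial_runners,  # pre-built List[TrialRunner]
+            optimizer=optimizer,
+            storage=storage,
+            root_env_config="environment.jsonc",
+        )
+        with scheduler:
+            scheduler.start()
+            scheduler.teardown()
+        ```
+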
        Parameters
        ----------
        config : dict
-            The configuration for the scheduler.
+            The configuration for the Scheduler.
        global_config : dict
-            he global configuration for the experiment.
-        environment : Environment
-            The environment to benchmark/optimize.
+            The global configuration for the experiment.
+        trial_runners : List[TrialRunner]
+            The set of TrialRunner(s) (and associated Environment(s)) to benchmark/optimize.
        optimizer : Optimizer
-            The optimizer to use.
+            The Optimizer to use.
        storage : Storage
            The storage to use.
        root_env_config : str
-            Path to the root environment configuration.
+            Path to the root Environment configuration.
        """
        self.global_config = global_config
        config = merge_parameters(
@@ -81,11 +82,13 @@ def __init__(  # pylint: disable=too-many-arguments

        self._do_teardown = bool(config.get("teardown", True))

-        self.experiment: Optional[Storage.Experiment] = None
-        self.environment = environment
-        self.optimizer = optimizer
-        self.storage = storage
+        self._experiment: Optional[Storage.Experiment] = None
+        self._trial_runners = trial_runners
+        assert self._trial_runners, "At least one TrialRunner is required"
+        self._optimizer = optimizer
+        self._storage = storage
        self._root_env_config = root_env_config
+        self._current_trial_runner_idx = 0

        self._last_trial_id = -1
        self._ran_trials: List[Storage.Trial] = []
@@ -126,6 +129,122 @@ def max_trials(self) -> int:
        """
        return self._max_trials

+    @property
+    def experiment(self) -> Optional[Storage.Experiment]:
+        """Gets the Experiment Storage."""
+        return self._experiment
+
+    @property
+    def root_environment(self) -> Environment:
+        """
+        Gets the root (prototypical) Environment from the first TrialRunner.
+
+        Note: All TrialRunners have the same Environment config and are made
+        unique by their use of the unique trial_runner_id assigned to each
+        TrialRunner's Environment's global_config.
+        """
+        return self._trial_runners[0].environment
+
+    @property
+    def trial_runners(self) -> List[TrialRunner]:
+        """Gets the list of TrialRunners."""
+        return self._trial_runners
+
+    @property
+    def environments(self) -> Iterable[Environment]:
+        """Gets the Environments of the TrialRunners."""
+        return (trial_runner.environment for trial_runner in self._trial_runners)
+
+    @property
+    def optimizer(self) -> Optimizer:
+        """Gets the Optimizer."""
+        return self._optimizer
+
+    @property
+    def storage(self) -> Storage:
+        """Gets the Storage."""
+        return self._storage
+
+    def _assign_trial_runner(
+        self,
+        trial: Storage.Trial,
+        trial_runner: Optional[TrialRunner] = None,
+    ) -> TrialRunner:
+        """
+        Assigns a TrialRunner to the given Trial.
+
+        The base class implements a simple round-robin scheduling algorithm.
+
+        Subclasses can override this method to implement a more sophisticated policy.
+        For instance:
+
+        ```python
+        def _assign_trial_runner(
+            self,
+            trial: Storage.Trial,
+            trial_runner: Optional[TrialRunner] = None,
+        ) -> TrialRunner:
+            if trial_runner is None:
+                # Implement a more sophisticated policy here.
+                # For example, assign the Trial to the TrialRunner with the least
+                # number of running Trials, or to the TrialRunner that hasn't
+                # executed this TunableValues config yet.
+                trial_runner = ...
+            # Call the base class method to record the TrialRunner in the Trial's metadata.
+            return super()._assign_trial_runner(trial, trial_runner)
+        ```
+
+        Parameters
+        ----------
+        trial : Storage.Trial
+            The trial to assign a TrialRunner to.
+        trial_runner : Optional[TrialRunner]
+            The TrialRunner to assign to the given Trial.
+            If None, one is chosen via the round-robin policy.
+
+        Returns
+        -------
+        TrialRunner
+            The assigned TrialRunner.
+        """
+        assert (
+            trial.trial_runner_id is None
+        ), f"Trial {trial} already has a TrialRunner assigned: {trial.trial_runner_id}"
+        if trial_runner is None:
+            # Basic round-robin trial runner assignment policy:
+            # fetch and increment the current TrialRunner index.
+            # Override in the subclass for a more sophisticated policy.
+            trial_runner_idx = self._current_trial_runner_idx
+            self._current_trial_runner_idx += 1
+            self._current_trial_runner_idx %= len(self._trial_runners)
+            trial_runner = self._trial_runners[trial_runner_idx]
+            _LOG.info(
+                "Trial %s missing trial_runner_id. Assigning %s via basic round-robin policy.",
+                trial,
+                trial_runner,
+            )
+        trial.add_new_config_data({"trial_runner_id": trial_runner.trial_runner_id})
+        return trial_runner
+
+    def get_trial_runner(self, trial: Storage.Trial) -> TrialRunner:
+        """
+        Gets the TrialRunner associated with the given Trial.
+
+        Parameters
+        ----------
+        trial : Storage.Trial
+            The trial to get the associated TrialRunner for.
+
+        Returns
+        -------
+        TrialRunner
+            The TrialRunner assigned to the given Trial.
+        """
+        if trial.trial_runner_id is None:
+            self._assign_trial_runner(trial, trial_runner=None)
+        assert trial.trial_runner_id is not None
+        # Look the TrialRunner up by its id rather than by list index, since
+        # trial_runner_ids need not be zero-based (e.g., in the test fixtures).
+        trial_runner = next(
+            (
+                runner
+                for runner in self._trial_runners
+                if runner.trial_runner_id == trial.trial_runner_id
+            ),
+            None,
+        )
+        assert trial_runner is not None, f"Unknown TrialRunner for {trial}"
+        return trial_runner
+
     def __repr__(self) -> str:
         """
         Produce a human-readable version of the Scheduler (mostly for logging).
@@ -141,18 +260,17 @@ def __enter__(self) -> "Scheduler":
         """Enter the scheduler's context."""
         _LOG.debug("Scheduler START :: %s", self)
         assert self.experiment is None
-        self.environment.__enter__()
-        self.optimizer.__enter__()
+        self._optimizer.__enter__()

         # Start new or resume the existing experiment. Verify that the
         # experiment configuration is compatible with the previous runs.
         # If the `merge` config parameter is present, merge in the data
         # from other experiments and check for compatibility.
- self.experiment = self.storage.experiment( + self._experiment = self.storage.experiment( experiment_id=self._experiment_id, trial_id=self._trial_id, root_env_config=self._root_env_config, - description=self.environment.name, - tunables=self.environment.tunable_params, + description=self.root_environment.name, + tunables=self.root_environment.tunable_params, opt_targets=self.optimizer.targets, ).__enter__() return self @@ -169,55 +287,57 @@ def __exit__( else: assert ex_type and ex_val _LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb)) - assert self.experiment is not None - self.experiment.__exit__(ex_type, ex_val, ex_tb) - self.optimizer.__exit__(ex_type, ex_val, ex_tb) - self.environment.__exit__(ex_type, ex_val, ex_tb) - self.experiment = None + assert self._experiment is not None + self._experiment.__exit__(ex_type, ex_val, ex_tb) + self._optimizer.__exit__(ex_type, ex_val, ex_tb) + self._experiment = None return False # Do not suppress exceptions @abstractmethod def start(self) -> None: - """Start the optimization loop.""" + """Start the scheduling loop.""" assert self.experiment is not None _LOG.info( "START: Experiment: %s Env: %s Optimizer: %s", - self.experiment, - self.environment, + self._experiment, + self.root_environment, self.optimizer, ) if _LOG.isEnabledFor(logging.INFO): - _LOG.info("Root Environment:\n%s", self.environment.pprint()) + _LOG.info("Root Environment:\n%s", self.root_environment.pprint()) if self._config_id > 0: - tunables = self.load_config(self._config_id) + tunables = self.load_tunable_config(self._config_id) self.schedule_trial(tunables) def teardown(self) -> None: """ - Tear down the environment. + Tear down the TrialRunners/Environment(s). Call it after the completion of the `.start()` in the scheduler context. """ assert self.experiment is not None if self._do_teardown: - self.environment.teardown() + for trial_runner in self.trial_runners: + assert not trial_runner.is_running + trial_runner.teardown() def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: """Get the best observation from the optimizer.""" (best_score, best_config) = self.optimizer.get_best_observation() - _LOG.info("Env: %s best score: %s", self.environment, best_score) + _LOG.info("Env: %s best score: %s", self.root_environment, best_score) return (best_score, best_config) - def load_config(self, config_id: int) -> TunableGroups: + def load_tunable_config(self, config_id: int) -> TunableGroups: """Load the existing tunable configuration from the storage.""" assert self.experiment is not None tunable_values = self.experiment.load_tunable_config(config_id) - tunables = self.environment.tunable_params.assign(tunable_values) + for environment in self.environments: + tunables = environment.tunable_params.assign(tunable_values) _LOG.info("Load config from storage: %d", config_id) if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2)) - return tunables + return tunables.copy() def _schedule_new_optimizer_suggestions(self) -> bool: """ @@ -242,6 +362,9 @@ def _schedule_new_optimizer_suggestions(self) -> bool: def schedule_trial(self, tunables: TunableGroups) -> None: """Add a configuration to the queue of trials.""" + # TODO: Alternative scheduling policies may prefer to expand repeats over + # time as well as space, or adjust the number of repeats (budget) of a given + # trial based on whether initial results are promising. 
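+        # For example (an illustrative sketch, not implemented here), a subclass
+        # could queue just one repeat up front and only schedule the remaining
+        # repeats from _schedule_new_optimizer_suggestions() once the first
+        # result for this config lands near the optimizer's best observation.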
for repeat_i in range(1, self._trial_config_repeat_count + 1): self._add_trial_to_queue( tunables, @@ -271,13 +394,17 @@ def _add_trial_to_queue( config: Optional[Dict[str, Any]] = None, ) -> None: """ - Add a configuration to the queue of trials. + Add a configuration to the queue of trials in the Storage backend. A wrapper for the `Experiment.new_trial` method. """ assert self.experiment is not None trial = self.experiment.new_trial(tunables, ts_start, config) - _LOG.info("QUEUE: Add new trial: %s", trial) + # Select a TrialRunner based on the trial's metadata. + # TODO: May want to further split this in the future to support scheduling a + # batch of new trials. + trial_runner = self._assign_trial_runner(trial, trial_runner=None) + _LOG.info("QUEUE: Added new trial: %s (assigned to %s)", trial, trial_runner) def _run_schedule(self, running: bool = False) -> None: """ diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py index e56d15ca17..4b864942dc 100644 --- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py @@ -5,11 +5,7 @@ """A simple single-threaded synchronous optimization loop implementation.""" import logging -from datetime import datetime -from pytz import UTC - -from mlos_bench.environments.status import Status from mlos_bench.schedulers.base_scheduler import Scheduler from mlos_bench.storage.base_storage import Storage @@ -41,25 +37,7 @@ def run_trial(self, trial: Storage.Trial) -> None: Save the results in the storage. """ super().run_trial(trial) - - if not self.environment.setup(trial.tunables, trial.config(self.global_config)): - _LOG.warning("Setup failed: %s :: %s", self.environment, trial.tunables) - # FIXME: Use the actual timestamp from the environment. - _LOG.info("QUEUE: Update trial results: %s :: %s", trial, Status.FAILED) - trial.update(Status.FAILED, datetime.now(UTC)) - return - - # Block and wait for the final result. - (status, timestamp, results) = self.environment.run() - _LOG.info("Results: %s :: %s\n%s", trial.tunables, status, results) - - # In async mode (TODO), poll the environment for status and telemetry - # and update the storage with the intermediate results. - (_status, _timestamp, telemetry) = self.environment.status() - - # Use the status and timestamp from `.run()` as it is the final status of the experiment. - # TODO: Use the `.status()` output in async mode. - trial.update_telemetry(status, timestamp, telemetry) - - trial.update(status, timestamp, results) - _LOG.info("QUEUE: Update trial results: %s :: %s %s", trial, status, results) + # In the sync scheduler we run each trial on its own TrialRunner in sequence. + trial_runner = self.get_trial_runner(trial) + trial_runner.run_trial(trial, self.global_config) + _LOG.info("QUEUE: Finished trial: %s on %s", trial, trial_runner) diff --git a/mlos_bench/mlos_bench/schedulers/trial_runner.py b/mlos_bench/mlos_bench/schedulers/trial_runner.py new file mode 100644 index 0000000000..a8da73dc26 --- /dev/null +++ b/mlos_bench/mlos_bench/schedulers/trial_runner.py @@ -0,0 +1,142 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+#
+"""Simple class to run an individual Trial on a given Environment."""
+
+import logging
+from datetime import datetime
+from types import TracebackType
+from typing import Any, Dict, Literal, Optional, Type
+
+from pytz import UTC
+
+from mlos_bench.environments.base_environment import Environment
+from mlos_bench.environments.status import Status
+from mlos_bench.event_loop_context import EventLoopContext
+from mlos_bench.storage.base_storage import Storage
+
+_LOG = logging.getLogger(__name__)
+
+
+class TrialRunner:
+    """
+    Simple class to help run an individual Trial on a given Environment.
+
+    TrialRunner manages the lifecycle of a single trial, including setup, run, teardown,
+    and async status polling via EventLoopContext background threads.
+
+    Multiple TrialRunners can be used in a multi-processing pool to run multiple trials
+    in parallel, for instance.
+    """
+
+    def __init__(self, trial_runner_id: int, env: Environment) -> None:
+        self._trial_runner_id = trial_runner_id
+        self._env = env
+        assert self._env.parameters["trial_runner_id"] == self._trial_runner_id
+        self._in_context = False
+        self._is_running = False
+        self._event_loop_context = EventLoopContext()
+
+    @property
+    def trial_runner_id(self) -> int:
+        """Get the TrialRunner's id."""
+        return self._trial_runner_id
+
+    @property
+    def environment(self) -> Environment:
+        """Get the Environment."""
+        return self._env
+
+    def __enter__(self) -> "TrialRunner":
+        assert not self._in_context
+        _LOG.debug("TrialRunner START :: %s", self)
+        # TODO: self._event_loop_context.enter()
+        self._env.__enter__()
+        self._in_context = True
+        return self
+
+    def __exit__(
+        self,
+        ex_type: Optional[Type[BaseException]],
+        ex_val: Optional[BaseException],
+        ex_tb: Optional[TracebackType],
+    ) -> Literal[False]:
+        assert self._in_context
+        _LOG.debug("TrialRunner END :: %s", self)
+        self._env.__exit__(ex_type, ex_val, ex_tb)
+        # TODO: self._event_loop_context.exit()
+        self._in_context = False
+        return False  # Do not suppress exceptions
+
+    @property
+    def is_running(self) -> bool:
+        """Get the running state of the current TrialRunner."""
+        return self._is_running
+
+    def run_trial(
+        self,
+        trial: Storage.Trial,
+        global_config: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Run a single trial on this TrialRunner's Environment and store the results in
+        the backend Trial Storage.
+
+        Parameters
+        ----------
+        trial : Storage.Trial
+            A Storage class based Trial used to persist the experiment trial data.
+        global_config : dict
+            Global configuration parameters.
+
+        Notes
+        -----
+        The trial's status, telemetry, and results are persisted in the backend
+        Trial Storage via `trial.update()` rather than returned to the caller.
+        """
+        assert self._in_context
+
+        assert not self._is_running
+        self._is_running = True
+
+        assert trial.trial_runner_id == self.trial_runner_id, (
+            f"TrialRunner {self} should not run trial {trial} "
+            f"with different trial_runner_id {trial.trial_runner_id}."
+        )
+
+        if not self.environment.setup(trial.tunables, trial.config(global_config)):
+            _LOG.warning("Setup failed: %s :: %s", self.environment, trial.tunables)
+            # FIXME: Use the actual timestamp from the environment.
+            _LOG.info("TrialRunner: Update trial results: %s :: %s", trial, Status.FAILED)
+            trial.update(Status.FAILED, datetime.now(UTC))
+            # Reset the running flag before bailing out on setup failure.
+            self._is_running = False
+            return
+
+        # TODO: start background status polling of the environments in the event loop.
+
+        # Block and wait for the final result.
+ (status, timestamp, results) = self.environment.run() + _LOG.info("TrialRunner Results: %s :: %s\n%s", trial.tunables, status, results) + + # In async mode (TODO), poll the environment for status and telemetry + # and update the storage with the intermediate results. + (_status, _timestamp, telemetry) = self.environment.status() + + # Use the status and timestamp from `.run()` as it is the final status of the experiment. + # TODO: Use the `.status()` output in async mode. + trial.update_telemetry(status, timestamp, telemetry) + + trial.update(status, timestamp, results) + _LOG.info("TrialRunner: Update trial results: %s :: %s %s", trial, status, results) + + self._is_running = False + + def teardown(self) -> None: + """ + Tear down the Environment. + + Call it after the completion of one (or more) `.run()` in the TrialRunner + context. + """ + assert self._in_context + self._env.teardown() diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index 72bfad007d..563316aefa 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -46,6 +46,7 @@ if TYPE_CHECKING: from mlos_bench.schedulers.base_scheduler import Scheduler + from mlos_bench.schedulers.trial_runner import TrialRunner from mlos_bench.storage.base_storage import Storage @@ -350,7 +351,7 @@ def build_scheduler( # pylint: disable=too-many-arguments *, config: Dict[str, Any], global_config: Dict[str, Any], - environment: Environment, + trial_runners: List["TrialRunner"], optimizer: Optimizer, storage: "Storage", root_env_config: str, @@ -364,8 +365,8 @@ def build_scheduler( # pylint: disable=too-many-arguments Configuration of the class to instantiate, as loaded from JSON. global_config : dict Global configuration parameters. - environment : Environment - The environment to benchmark/optimize. + trial_runners : List[TrialRunner] + The TrialRunners (Environments) to use. optimizer : Optimizer The optimizer to use. 
storage : Storage @@ -387,7 +388,7 @@ def build_scheduler( # pylint: disable=too-many-arguments class_name, config=class_config, global_config=global_config, - environment=environment, + trial_runners=trial_runners, optimizer=optimizer, storage=storage, root_env_config=root_env_config, diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 867c4e0bc0..81a598c7e6 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -8,7 +8,7 @@ from abc import ABCMeta, abstractmethod from datetime import datetime from types import TracebackType -from typing import Any, Dict, Iterator, List, Optional, Tuple, Type +from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union from typing_extensions import Literal @@ -399,7 +399,10 @@ def __init__( # pylint: disable=too-many-arguments self._status = Status.UNKNOWN def __repr__(self) -> str: - return f"{self._experiment_id}:{self._trial_id}:{self._tunable_config_id}" + return ( + f"{self._experiment_id}:{self._trial_id}:" + f"{self._tunable_config_id}:{self.trial_runner_id}" + ) @property def trial_id(self) -> int: @@ -411,6 +414,11 @@ def tunable_config_id(self) -> int: """ID of the current trial (tunable) configuration.""" return self._tunable_config_id + @property + def trial_runner_id(self) -> Optional[int]: + """ID of the TrialRunner this trial is assigned to.""" + return self._config.get("trial_runner_id") + @property def opt_targets(self) -> Dict[str, Literal["min", "max"]]: """Get the Trial's optimization targets and directions.""" @@ -439,8 +447,52 @@ def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, An config.update(global_config or {}) config["experiment_id"] = self._experiment_id config["trial_id"] = self._trial_id + trial_runner_id = self.trial_runner_id + if trial_runner_id is not None: + config.setdefault("trial_runner_id", trial_runner_id) return config + def add_new_config_data( + self, + new_config_data: Dict[str, Union[int, float, str]], + ) -> None: + """ + Add new config data to the trial. + + Parameters + ---------- + new_config_data : Dict[str, Union[int, float, str]] + New data to add (must not already exist for the trial). + + Raises + ------ + ValueError + If any of the data already exists. + """ + + for key, value in new_config_data.items(): + if key in self._config: + raise ValueError( + f"New config data {key}={value} already exists for trial {self}: " + f"{self._config[key]}" + ) + self._config[key] = value + self._save_new_config_data(new_config_data) + + @abstractmethod + def _save_new_config_data( + self, + new_config_data: Dict[str, Union[int, float, str]], + ) -> None: + """ + Save the new config data to the storage. + + Parameters + ---------- + new_config_data : Dict[str, Union[int, float, str]] + New data to add. 
+ """ + @property def status(self) -> Status: """Get the status of the current trial.""" diff --git a/mlos_bench/mlos_bench/storage/base_trial_data.py b/mlos_bench/mlos_bench/storage/base_trial_data.py index 4782aa92b3..7162623117 100644 --- a/mlos_bench/mlos_bench/storage/base_trial_data.py +++ b/mlos_bench/mlos_bench/storage/base_trial_data.py @@ -38,6 +38,7 @@ def __init__( # pylint: disable=too-many-arguments ts_start: datetime, ts_end: Optional[datetime], status: Status, + trial_runner_id: Optional[int] = None, ): self._experiment_id = experiment_id self._trial_id = trial_id @@ -47,11 +48,12 @@ def __init__( # pylint: disable=too-many-arguments self._ts_start = ts_start self._ts_end = ts_end self._status = status + self._trial_runner_id = trial_runner_id def __repr__(self) -> str: return ( f"Trial :: {self._experiment_id}:{self._trial_id} " - f"cid:{self._tunable_config_id} {self._status.name}" + f"cid:{self._tunable_config_id} rid:{self._trial_runner_id} {self._status.name}" ) def __eq__(self, other: Any) -> bool: @@ -69,6 +71,13 @@ def trial_id(self) -> int: """ID of the trial.""" return self._trial_id + @property + def trial_runner_id(self) -> Optional[int]: + """ID of the TrialRunner.""" + if not self._trial_runner_id: + self._trial_runner_id = self.metadata_dict.get("trial_runner_id") + return self._trial_runner_id + @property def ts_start(self) -> datetime: """Start timestamp of the trial (UTC).""" diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index 3b0c6c31fb..6f3d594ac4 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -3,16 +3,46 @@ # Licensed under the MIT License. # """Common SQL methods for accessing the stored benchmark data.""" -from typing import Dict, Optional +from typing import Any, Dict, Optional import pandas -from sqlalchemy import Engine, Integer, and_, func, select +from sqlalchemy import Connection, Engine, Integer, Table, and_, func, select from mlos_bench.environments.status import Status from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.base_trial_data import TrialData from mlos_bench.storage.sql.schema import DbSchema -from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp +from mlos_bench.util import nullable, utcify_nullable_timestamp, utcify_timestamp + + +def save_params( + conn: Connection, + table: Table, + params: Dict[str, Any], + **kwargs: Any, +) -> None: + """Updates a set of (param_id, param_value) tuples in the given Table. + + Parameters + ---------- + conn : Connection + A connection to the backend database. + table : Table + The table to update. + params : Dict[str, Any] + The new (param_id, param_value) tuples to upsert to the Table. + **kwargs : Dict[str, Any] + Primary key info for the given table. + """ + if not params: + return + conn.execute( + table.insert(), + [ + {**kwargs, "param_id": key, "param_value": nullable(str, val)} + for (key, val) in params.items() + ], + ) def get_trials( @@ -34,6 +64,13 @@ def get_trials( # Build up sql a statement for fetching trials. 
    stmt = (
         schema.trial.select()
+        .join(
+            schema.trial_param,
+            and_(
+                schema.trial.c.trial_id == schema.trial_param.c.trial_id,
+                schema.trial.c.exp_id == schema.trial_param.c.exp_id,
+                schema.trial_param.c.param_id == "trial_runner_id",
+            ),
+            isouter=True,
+        )
         .where(
             schema.trial.c.exp_id == experiment_id,
         )
@@ -58,6 +95,7 @@ def get_trials(
             ts_start=utcify_timestamp(trial.ts_start, origin="utc"),
             ts_end=utcify_nullable_timestamp(trial.ts_end, origin="utc"),
             status=Status[trial.status],
+            trial_runner_id=nullable(int, trial.param_value),
         )
         for trial in trials.fetchall()
     }
@@ -108,6 +146,13 @@ def get_results_df(
             schema.trial,
             tunable_config_trial_group_id_subquery,
         )
+        .join(
+            schema.trial_param,
+            and_(
+                schema.trial.c.trial_id == schema.trial_param.c.trial_id,
+                schema.trial.c.exp_id == schema.trial_param.c.exp_id,
+                schema.trial_param.c.param_id == "trial_runner_id",
+            ),
+            isouter=True,
+        )
         .where(
             schema.trial.c.exp_id == experiment_id,
             and_(
@@ -135,6 +180,7 @@ def get_results_df(
                 row.config_id,
                 row.tunable_config_trial_group_id,
                 row.status,
+                row.param_value,
             )
             for row in cur_trials.fetchall()
         ],
@@ -145,6 +191,7 @@ def get_results_df(
             "tunable_config_id",
             "tunable_config_trial_group_id",
             "status",
+            "trial_runner_id",
         ],
     )
diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py
index 56a3f26049..abd9fe80e9 100644
--- a/mlos_bench/mlos_bench/storage/sql/experiment.py
+++ b/mlos_bench/mlos_bench/storage/sql/experiment.py
@@ -14,10 +14,11 @@

 from mlos_bench.environments.status import Status
 from mlos_bench.storage.base_storage import Storage
+from mlos_bench.storage.sql.common import save_params
 from mlos_bench.storage.sql.schema import DbSchema
 from mlos_bench.storage.sql.trial import Trial
 from mlos_bench.tunables.tunable_groups import TunableGroups
-from mlos_bench.util import nullable, utcify_timestamp
+from mlos_bench.util import utcify_timestamp

 _LOG = logging.getLogger(__name__)

@@ -224,23 +225,6 @@ def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> D
         row._tuple() for row in cur_result.fetchall()  # pylint: disable=protected-access
     )

-    @staticmethod
-    def _save_params(
-        conn: Connection,
-        table: Table,
-        params: Dict[str, Any],
-        **kwargs: Any,
-    ) -> None:
-        if not params:
-            return
-        conn.execute(
-            table.insert(),
-            [
-                {**kwargs, "param_id": key, "param_value": nullable(str, val)}
-                for (key, val) in params.items()
-            ],
-        )
-
     def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
         timestamp = utcify_timestamp(timestamp, origin="local")
         _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
@@ -302,7 +286,7 @@ def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
         config_id: int = conn.execute(
             self._schema.config.insert().values(config_hash=config_hash)
         ).inserted_primary_key[0]
-        self._save_params(
+        save_params(
             conn,
             self._schema.config_param,
             {tunable.name: tunable.value for (tunable, _group) in tunables},
@@ -338,7 +322,7 @@ def _new_trial(
         # Note: config here is the framework config, not the target
         # environment config (i.e., tunables).
if config is not None: - self._save_params( + save_params( conn, self._schema.trial_param, config, diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index 3900568b75..431cfe1bb1 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -148,6 +148,8 @@ def __init__(self, engine: Engine): # Values of additional non-tunable parameters of the trial, # e.g., scheduled execution time, VM name / location, number of repeats, etc. + # In particular, the trial_runner_id is stored here (in part to avoid + # updating the trial table schema). self.trial_param = Table( "trial_param", self._meta, diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 5942912efd..75cb65d0cc 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -6,13 +6,14 @@ import logging from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, Union from sqlalchemy import Connection, Engine from sqlalchemy.exc import IntegrityError from mlos_bench.environments.status import Status from mlos_bench.storage.base_storage import Storage +from mlos_bench.storage.sql.common import save_params from mlos_bench.storage.sql.schema import DbSchema from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.util import nullable, utcify_timestamp @@ -46,6 +47,16 @@ def __init__( # pylint: disable=too-many-arguments self._engine = engine self._schema = schema + def _save_new_config_data(self, new_config_data: Dict[str, Union[int, float, str]]) -> None: + with self._engine.begin() as conn: + save_params( + conn, + self._schema.trial_param, + new_config_data, + exp_id=self._experiment_id, + trial_id=self._trial_id, + ) + def update( self, status: Status, diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py index ac57b7b5c0..66b5d06b19 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial_data.py +++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py @@ -36,6 +36,7 @@ def __init__( # pylint: disable=too-many-arguments ts_start: datetime, ts_end: Optional[datetime], status: Status, + trial_runner_id: Optional[int] = None, ): super().__init__( experiment_id=experiment_id, @@ -44,6 +45,7 @@ def __init__( # pylint: disable=too-many-arguments ts_start=ts_start, ts_end=ts_end, status=status, + trial_runner_id=trial_runner_id, ) self._engine = engine self._schema = schema diff --git a/mlos_bench/mlos_bench/tests/config/cli/test-cli-config.jsonc b/mlos_bench/mlos_bench/tests/config/cli/test-cli-config.jsonc index 436507ce84..4bfd00c3c3 100644 --- a/mlos_bench/mlos_bench/tests/config/cli/test-cli-config.jsonc +++ b/mlos_bench/mlos_bench/tests/config/cli/test-cli-config.jsonc @@ -18,6 +18,7 @@ ], "trial_config_repeat_count": 2, + "num_trial_runners": 3, "random_seed": 42, "random_init": true diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test-cases/bad/invalid/min-trial-runners-count.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/cli/test-cases/bad/invalid/min-trial-runners-count.jsonc new file mode 100644 index 0000000000..251a00c89e --- /dev/null +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test-cases/bad/invalid/min-trial-runners-count.jsonc @@ -0,0 +1,3 @@ +{ + "num_trial_runners": 0 // too small +} diff --git 
a/mlos_bench/mlos_bench/tests/config/schemas/cli/test-cases/good/full/full-cli.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/cli/test-cases/good/full/full-cli.jsonc index 256bd1d687..0373ec3a3e 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test-cases/good/full/full-cli.jsonc +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test-cases/good/full/full-cli.jsonc @@ -16,6 +16,7 @@ "storage": "storage/sqlite.jsonc", "trial_config_repeat_count": 3, + "num_trial_runners": 3, "random_init": true, "random_seed": 42, diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py index 118fa13ba9..4facb39db7 100644 --- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py @@ -70,6 +70,16 @@ def _get_launcher(desc: str, cli_args: str) -> Launcher: # Check the basic parent service assert isinstance(launcher.service, SupportsConfigLoading) # built-in assert isinstance(launcher.service, SupportsLocalExec) # built-in + # All trial runners should have the same Environment class. + assert ( + len(set(trial_runner.environment.__class__ for trial_runner in launcher.trial_runners)) + == 1 + ) + # Make sure that each trial runner has a unique ID. + assert set( + trial_runner.environment.const_args["trial_runner_id"] + for trial_runner in launcher.trial_runners + ) == set(range(0, len(launcher.trial_runners))) return launcher @@ -95,10 +105,16 @@ def test_launcher_args_parse_defaults(config_paths: List[str]) -> None: ) assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" assert launcher.teardown # defaults + # Make sure we have the right number of trial runners. + assert len(launcher.trial_runners) == 1 # defaults # Check that the environment that got loaded looks to be of the right type. env_config = launcher.config_loader.load_config(ENV_CONF_PATH, ConfigSchema.ENVIRONMENT) assert env_config["class"] == "mlos_bench.environments.mock_env.MockEnv" - assert check_class_name(launcher.environment, env_config["class"]) + # All TrialRunners should get the same Environment. + assert all( + check_class_name(trial_runner.environment, env_config["class"]) + for trial_runner in launcher.trial_runners + ) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, OneShotOptimizer) # Check that the optimizer got initialized with defaults. @@ -121,6 +137,7 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: cli_args = ( "--config-paths " + " ".join(config_paths) + + " --num-trial-runners 5" + " --service services/remote/mock/mock_auth_service.jsonc" + " services/remote/mock/mock_remote_exec_service.jsonc" + " --scheduler schedulers/sync_scheduler.jsonc" @@ -148,9 +165,16 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: ) assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" assert launcher.teardown + # Make sure we have the right number of trial runners. + assert len(launcher.trial_runners) == 5 # from cli args # Check that the environment that got loaded looks to be of the right type. env_config = launcher.config_loader.load_config(ENV_CONF_PATH, ConfigSchema.ENVIRONMENT) assert env_config["class"] == "mlos_bench.environments.mock_env.MockEnv" + # All TrialRunners should get the same Environment. + assert all( + check_class_name(trial_runner.environment, env_config["class"]) + for trial_runner in launcher.trial_runners + ) # Check that the optimizer looks right. 
assert isinstance(launcher.optimizer, OneShotOptimizer) # Check that the optimizer got initialized with defaults. @@ -207,10 +231,16 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: path_join(path, abs_path=True) for path in config_paths + config["config_path"] ] + # Make sure we have the right number of trial runners. + assert len(launcher.trial_runners) == 3 # from test-cli-config.jsonc # Check that the environment that got loaded looks to be of the right type. env_config_file = config["environment"] env_config = launcher.config_loader.load_config(env_config_file, ConfigSchema.ENVIRONMENT) - assert check_class_name(launcher.environment, env_config["class"]) + # All TrialRunners should get the same Environment. + assert all( + check_class_name(trial_runner.environment, env_config["class"]) + for trial_runner in launcher.trial_runners + ) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, MlosCoreOptimizer) diff --git a/mlos_bench/mlos_bench/tests/storage/__init__.py b/mlos_bench/mlos_bench/tests/storage/__init__.py index c3b294cae1..9f3819c35f 100644 --- a/mlos_bench/mlos_bench/tests/storage/__init__.py +++ b/mlos_bench/mlos_bench/tests/storage/__init__.py @@ -6,3 +6,4 @@ CONFIG_COUNT = 10 CONFIG_TRIAL_REPEAT_COUNT = 3 +TRIAL_RUNNER_COUNT = 5 diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index 4e92d9ab9d..f2eb92e6da 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -5,17 +5,18 @@ """Test fixtures for mlos_bench storage.""" from random import seed as rand_seed -from typing import Generator +from typing import Generator, List import pytest from mlos_bench.environments.mock_env import MockEnv from mlos_bench.optimizers.mock_optimizer import MockOptimizer from mlos_bench.schedulers.sync_scheduler import SyncScheduler +from mlos_bench.schedulers.trial_runner import TrialRunner from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.sql.storage import SqlStorage from mlos_bench.tests import SEED -from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT +from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT, TRIAL_RUNNER_COUNT from mlos_bench.tunables.tunable_groups import TunableGroups # pylint: disable=redefined-outer-name @@ -129,16 +130,25 @@ def _dummy_run_exp( rand_seed(SEED) - env = MockEnv( - name="Test Env", - config={ - "tunable_params": list(exp.tunables.get_covariant_group_names()), - "mock_env_seed": SEED, - "mock_env_range": [60, 120], - "mock_env_metrics": ["score"], - }, - tunables=exp.tunables, - ) + trial_runners: List[TrialRunner] = [] + global_config: dict = {} + # TODO: Make a utility function for this? + for i in range(1, TRIAL_RUNNER_COUNT): + # Create a new global config for each Environment with a unique trial_runner_id for it. 
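+        # Note: this fixture numbers its TrialRunners starting from 1 (the
+        # Launcher numbers them from 0), so Scheduler lookups must match on
+        # trial_runner_id rather than on list position.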
+            global_config_copy = global_config.copy()
+            global_config_copy["trial_runner_id"] = i
+            env = MockEnv(
+                name="Test Env",
+                config={
+                    "tunable_params": list(exp.tunables.get_covariant_group_names()),
+                    "mock_env_seed": SEED,
+                    "mock_env_range": [60, 120],
+                    "mock_env_metrics": ["score"],
+                },
+                global_config=global_config_copy,
+                tunables=exp.tunables,
+            )
+            trial_runners.append(TrialRunner(trial_runner_id=i, env=env))

     opt = MockOptimizer(
         tunables=exp.tunables,
@@ -150,6 +160,7 @@ def _dummy_run_exp(
             # default values for the tunable params)
             # "start_with_defaults": True,
         },
+        global_config=global_config,
     )

     scheduler = SyncScheduler(
@@ -161,8 +172,8 @@ def _dummy_run_exp(
             "trial_config_repeat_count": CONFIG_TRIAL_REPEAT_COUNT,
             "max_trials": CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT,
         },
-        global_config={},
-        environment=env,
+        global_config=global_config,
+        trial_runners=trial_runners,
         optimizer=opt,
         storage=storage,
         root_env_config=exp.root_env_config,
diff --git a/mlos_bench/mlos_bench/tests/storage/trial_data_test.py b/mlos_bench/mlos_bench/tests/storage/trial_data_test.py
index ea513eace2..77bc1eb243 100644
--- a/mlos_bench/mlos_bench/tests/storage/trial_data_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/trial_data_test.py
@@ -21,6 +21,7 @@ def test_exp_trial_data(exp_data: ExperimentData) -> None:
     assert trial.tunable_config_id == expected_config_id
     assert trial.status == Status.SUCCEEDED
     assert trial.metadata_dict["repeat_i"] == 1
+    assert trial.metadata_dict["trial_runner_id"] == "1"
     assert list(trial.results_dict.keys()) == ["score"]
     assert trial.results_dict["score"] == pytest.approx(73.27, 0.01)
     assert isinstance(trial.ts_start, datetime)
diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py
index 8721bbe451..78d1ebfe54 100644
--- a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py
@@ -10,6 +10,8 @@
 from mlos_bench.tests.storage import CONFIG_TRIAL_REPEAT_COUNT
 from mlos_bench.tunables.tunable_groups import TunableGroups


 def test_trial_data_tunable_config_data(
     exp_data: ExperimentData,