Skip to content

Commit

Permalink
fix: prompt for multiple options
Browse files (browse the repository at this point in the history)
  • Loading branch information
bojiang committed Aug 12, 2024
1 parent f4be770 commit cc7bd0a
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 36 deletions.
6 changes: 4 additions & 2 deletions src/openllm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import questionary
import typer

from openllm.accelerator_spec import DeploymentTarget, can_run, get_local_machine_spec
from openllm.accelerator_spec import (DeploymentTarget, can_run,
get_local_machine_spec)
from openllm.analytic import DO_NOT_TRACK, OpenLLMTyper
from openllm.clean import app as clean_app
from openllm.cloud import deploy as cloud_deploy
from openllm.cloud import ensure_cloud_context, get_cloud_machine_spec
from openllm.common import CHECKED, INTERACTIVE, VERBOSE_LEVEL, BentoInfo, output
from openllm.common import (CHECKED, INTERACTIVE, VERBOSE_LEVEL, BentoInfo,
output)
from openllm.local import run as local_run
from openllm.local import serve as local_serve
from openllm.model import app as model_app
Expand Down
3 changes: 2 additions & 1 deletion src/openllm/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from openllm.accelerator_spec import ACCELERATOR_SPECS
from openllm.analytic import OpenLLMTyper
from openllm.common import INTERACTIVE, BentoInfo, DeploymentTarget, output, run_command
from openllm.common import (INTERACTIVE, BentoInfo, DeploymentTarget, output,
run_command)

app = OpenLLMTyper()

Expand Down
1 change: 0 additions & 1 deletion src/openllm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def patch(self, value: T):

VERBOSE_LEVEL = ContextVar(10)
INTERACTIVE = ContextVar(False)
FORCE = ContextVar(False)


def output(content, level=0, style=None, end=None):
Expand Down
49 changes: 17 additions & 32 deletions src/openllm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from openllm.accelerator_spec import DeploymentTarget, can_run
from openllm.analytic import OpenLLMTyper
from openllm.common import FORCE, VERBOSE_LEVEL, BentoInfo, load_config, output
from openllm.common import VERBOSE_LEVEL, BentoInfo, load_config, output
from openllm.repo import ensure_repo_updated, parse_repo_url

app = OpenLLMTyper(help='manage models')
Expand Down Expand Up @@ -54,45 +54,28 @@ def is_seen(value):
output(table)


def ensure_bento(model: str, target: Optional[DeploymentTarget] = None, repo_name: Optional[str] = None) -> BentoInfo:
def ensure_bento(
model: str,
target: Optional[DeploymentTarget] = None,
repo_name: Optional[str] = None,
) -> BentoInfo:
bentos = list_bento(model, repo_name=repo_name)
if len(bentos) == 0:
output(f'No model found for {model}', style='red')
raise typer.Exit(1)

if len(bentos) == 1:
if FORCE.get():
output(f'Found model {bentos[0]}', style='green')
return bentos[0]
if target is None:
return bentos[0]
if can_run(bentos[0], target) <= 0:
return bentos[0]
output(f'Found model {bentos[0]}', style='green')
if target is not None and can_run(bentos[0], target) <= 0:
output(f'The machine({target.name}) with {target.accelerators_repr} does not appear to have sufficient '
f'resources to run model {bentos[0]}\n',
style='yellow')
return bentos[0]

if target is None:
output(f'Multiple models match {model}, did you mean one of these?', style='red')
for bento in bentos:
output(f' {bento}')
raise typer.Exit(1)

filtered = [bento for bento in bentos if can_run(bento, target) > 0]
if len(filtered) == 0:
output(f'No deployment target found for {model}', style='red')
raise typer.Exit(1)

if len(filtered) == 0:
output(f'No deployment target found for {model}', style='red')
raise typer.Exit(1)

if len(bentos) > 1:
output(f'Multiple models match {model}, did you mean one of these?', style='red')
for bento in bentos:
output(f' {bento}')
raise typer.Exit(1)

return bentos[0]
# multiple models matched: the choice is ambiguous, so list the candidates and exit
output(f'Multiple models match {model}, did you mean one of these?', style='red')
list_model(model, repo=repo_name)
raise typer.Exit(1)


NUMBER_RE = re.compile(r'\d+')
Expand All @@ -107,7 +90,9 @@ def _extract_first_number(s: str):


def list_bento(
tag: typing.Optional[str] = None, repo_name: typing.Optional[str] = None, include_alias: bool = False
tag: typing.Optional[str] = None,
repo_name: typing.Optional[str] = None,
include_alias: bool = False,
) -> typing.List[BentoInfo]:
ensure_repo_updated()

Expand Down

0 comments on commit cc7bd0a

Please sign in to comment.