From 07f572ea8c6f5f970f77a73e33bd76c0252c9f9b Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 29 Aug 2023 11:12:57 +0900 Subject: [PATCH 1/5] Add TritonModelInput with `optional` --- .pre-commit-config.yaml | 2 +- setup.cfg | 2 ++ tritony/helpers.py | 62 +++++++++++++++++++++++++++++++++-------- tritony/tools.py | 25 +++++++---------- 4 files changed, 64 insertions(+), 27 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72bd477..28793e6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: hooks: - id: flake8 types: [python] - args: ["--max-line-length", "120", "--ignore", "F811,F841,E203,E402,E712,W503"] + args: ["--max-line-length", "120", "--ignore", "F811,F841,E203,E402,E712,W503,E501"] - repo: https://github.com/shellcheck-py/shellcheck-py rev: v0.9.0.5 hooks: diff --git a/setup.cfg b/setup.cfg index 18add78..b1edf1b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,8 @@ classifiers = zip_safe = False include_package_data = True packages = find: +package_dir = + =. install_requires = tritonclient[all]>=2.21.0 protobuf>=3.5.0 diff --git a/tritony/helpers.py b/tritony/helpers.py index 1957c93..d872e5e 100644 --- a/tritony/helpers.py +++ b/tritony/helpers.py @@ -10,6 +10,7 @@ from attrs import define from tritonclient import grpc as grpcclient from tritonclient import http as httpclient +from tritonclient.grpc import model_config_pb2 class TritonProtocol(Enum): @@ -31,13 +32,32 @@ def dict_to_attr(obj: dict[str, Any]) -> SimpleNamespace: return json.loads(json.dumps(obj), object_hook=lambda d: SimpleNamespace(**d)) +@define +class TritonModelInput: + """ + Most of the fields are mapped to model_config_pb2.ModelInput(https://github.com/triton-inference-server/common/blob/a2de06f4c80b2c7b15469fa4d36e5f6445382bad/protobuf/model_config.proto#L317) + + Commented fields are not used. 
+ """ + + name: str + dtype: str # data_type mapping to https://github.com/triton-inference-server/client/blob/d257c0e5c3de6e15d6ef289ff2b96cecd0a69b5f/src/python/library/tritonclient/utils/__init__.py#L163-L190 + + format: int = 0 + dims: list[int] = [] # dims + + # reshape: list[int] = [] + # is_shape_tensor: bool = False + # allow_ragged_batch: bool = False + optional: bool = False + + @define class TritonModelSpec: name: str max_batch_size: int - input_name: list[str] - input_dtype: list[str] + model_input: list[TritonModelInput] output_name: list[str] @@ -91,7 +111,7 @@ def get_triton_client( model_name: str, model_version: str, protocol: TritonProtocol, -): +) -> (int, list[TritonModelInput], list[str]): """ (required in) :param triton_client: @@ -107,23 +127,43 @@ def get_triton_client( args = dict(model_name=model_name, model_version=model_version) - model_metadata = triton_client.get_model_metadata(**args) model_config = triton_client.get_model_config(**args) if protocol is TritonProtocol.http: - model_metadata = dict_to_attr(model_metadata) model_config = dict_to_attr(model_config) elif protocol is TritonProtocol.grpc: model_config = model_config.config - max_batch_size, input_name_list, output_name_list, dtype_list = parse_model(model_metadata, model_config) + max_batch_size, input_list, output_name_list = parse_model(model_config) + + return max_batch_size, input_list, output_name_list + - return max_batch_size, input_name_list, output_name_list, dtype_list +def parse_model_input( + model_input: model_config_pb2.ModelInput | SimpleNamespace, +) -> TritonModelInput: + """ + https://github.com/triton-inference-server/common/blob/r23.08/protobuf/model_config.proto#L317-L412 + """ + RAW_DTYPE = model_input.data_type + if isinstance(model_input.data_type, int): + RAW_DTYPE = model_config_pb2.DataType.Name(RAW_DTYPE) + RAW_DTYPE = RAW_DTYPE.strip("TYPE_") + + if RAW_DTYPE == "STRING": + RAW_DTYPE = "BYTES" # https://github.com/triton-inference-server/client/blob/d257c0e5c3de6e15d6ef289ff2b96cecd0a69b5f/src/python/library/tritonclient/utils/__init__.py#L188-L189 + return TritonModelInput( + name=model_input.name, + dims=model_input.dims, + dtype=RAW_DTYPE, + optional=model_input.optional, + ) -def parse_model(model_metadata, model_config): +def parse_model( + model_config: model_config_pb2.ModelConfig | SimpleNamespace, +) -> (int, list[TritonModelInput], list[str]): return ( model_config.max_batch_size, - [input_metadata.name for input_metadata in model_metadata.inputs], - [output_metadata.name for output_metadata in model_metadata.outputs], - [input_metadata.datatype for input_metadata in model_metadata.inputs], + [parse_model_input(model_config_input) for model_config_input in model_config.input], + [model_config_output.name for model_config_output in model_config.output], ) diff --git a/tritony/tools.py b/tritony/tools.py index c00acc9..135a358 100644 --- a/tritony/tools.py +++ b/tritony/tools.py @@ -6,7 +6,6 @@ import logging import os import time -import warnings from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional, Union @@ -198,14 +197,6 @@ def triton_client(self): def default_model_spec(self): return self.model_specs[self.default_model] - @property - def input_name_list(self): - warnings.warn( - "input_name_list is deprecated, please use 'default_model_spec.input_name' instead", DeprecationWarning - ) - - return self.default_model_spec.input_name - def __del__(self): # Not supporting streaming # if self.flag.protocol is 
TritonProtocol.grpc and self.flag.streaming and hasattr(self, "triton_client"): @@ -223,15 +214,14 @@ def _renew_triton_client(self, triton_client, model_name: str | None = None, mod triton_client.is_server_ready() triton_client.is_model_ready(model_name, model_version) - (max_batch_size, input_name_list, output_name_list, dtype_list) = get_triton_client( + (max_batch_size, input_list, output_name_list) = get_triton_client( triton_client, model_name=model_name, model_version=model_version, protocol=self.flag.protocol ) self.model_specs[(model_name, model_version)] = TritonModelSpec( name=model_name, max_batch_size=max_batch_size, - input_name=input_name_list, - input_dtype=dtype_list, + model_input=input_list, output_name=output_name_list, ) @@ -257,7 +247,12 @@ def __call__( if type(sequences_or_dict) in [list, np.ndarray]: sequences_list = [sequences_or_dict] elif type(sequences_or_dict) is dict: - sequences_list = [sequences_or_dict[input_name] for input_name in model_spec.input_name] + sequences_list = [ + sequences_or_dict[model_input.name] + for model_input in model_spec.model_input + if model_input.optional is False # check required + or (model_input.optional is True and model_input.name in sequences_or_dict) # check optional + ] return self._call_async(sequences_list, model_spec=model_spec) @@ -267,8 +262,8 @@ def build_triton_input(self, _input_list: List[np.array], model_spec: TritonMode else: client = httpclient infer_input_list = [] - for _input, _input_name, _dtype in zip(_input_list, model_spec.input_name, model_spec.input_dtype): - infer_input = client.InferInput(_input_name, _input.shape, _dtype) + for _input, _model_input in zip(_input_list, model_spec.model_input): + infer_input = client.InferInput(_model_input.name, _input.shape, _model_input.dtype) infer_input.set_data_from_numpy(_input) infer_input_list.append(infer_input) From 0d343d7193e59cadb9bf0c32546c3499c72119a1 Mon Sep 17 00:00:00 2001 From: Dongwoo Arthur Kim Date: Tue, 29 Aug 2023 23:39:37 +0900 Subject: [PATCH 2/5] Remove deprecated tests --- tests/test_model_call.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_model_call.py b/tests/test_model_call.py index 7f6f450..314b9d1 100644 --- a/tests/test_model_call.py +++ b/tests/test_model_call.py @@ -37,11 +37,7 @@ def test_with_input_name(protocol_and_port): client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol) - sample = np.random.rand(1, 100).astype(np.float32) - result = client({client.input_name_list[0]: sample}) - print(f"Result: {np.isclose(result, sample).all()}") - sample = np.random.rand(100, 100).astype(np.float32) - result = client({client.default_model_spec.input_name[0]: sample}) + result = client({client.default_model_spec.model_input[0].name: sample}) print(f"Result: {np.isclose(result, sample).all()}") From f071bc5b3defc8e3fe7f9b6b2f0731ea80c2fab9 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 30 Aug 2023 19:57:25 +0900 Subject: [PATCH 3/5] Support `parameters` on config.pbtxt --- model_repository/sample/1/model.py | 5 ++- model_repository/sample/config.pbtxt | 8 +++++ .../sample_autobatching/config.pbtxt | 8 +++++ model_repository/sample_multiple/config.pbtxt | 8 +++++ tests/test_model_call.py | 30 ++++++++++------ tritony/tools.py | 34 +++++++++++++++---- tritony/version.py | 2 +- 7 files changed, 75 insertions(+), 20 deletions(-) diff --git a/model_repository/sample/1/model.py b/model_repository/sample/1/model.py index ff8803a..1618bb2 100644 --- 
a/model_repository/sample/1/model.py +++ b/model_repository/sample/1/model.py @@ -13,10 +13,13 @@ def initialize(self, args): pb_utils.triton_string_to_numpy(output_config["data_type"]) for output_config in output_configs ] + parameters = self.model_config["parameters"] + def execute(self, requests): responses = [None for _ in requests] for idx, request in enumerate(requests): - in_tensor = [item.as_numpy() for item in request.inputs()] + current_add_value = int(json.loads(request.parameters()).get("add", 0)) + in_tensor = [item.as_numpy() + current_add_value for item in request.inputs()] out_tensor = [ pb_utils.Tensor(output_name, x.astype(output_dtype)) for x, output_name, output_dtype in zip(in_tensor, self.output_name_list, self.output_dtype_list) diff --git a/model_repository/sample/config.pbtxt b/model_repository/sample/config.pbtxt index cb6537d..60b403d 100644 --- a/model_repository/sample/config.pbtxt +++ b/model_repository/sample/config.pbtxt @@ -1,6 +1,14 @@ name: "sample" backend: "python" max_batch_size: 0 + +parameters [ + { + key: "add", + value: { string_value: "0" } + } +] + input [ { name: "model_in" diff --git a/model_repository/sample_autobatching/config.pbtxt b/model_repository/sample_autobatching/config.pbtxt index 14393b2..0d67899 100644 --- a/model_repository/sample_autobatching/config.pbtxt +++ b/model_repository/sample_autobatching/config.pbtxt @@ -1,6 +1,14 @@ name: "sample_autobatching" backend: "python" max_batch_size: 2 + +parameters [ + { + key: "add", + value: { string_value: "0" } + } +] + input [ { name: "model_in" diff --git a/model_repository/sample_multiple/config.pbtxt b/model_repository/sample_multiple/config.pbtxt index 7a880cf..8a6b357 100644 --- a/model_repository/sample_multiple/config.pbtxt +++ b/model_repository/sample_multiple/config.pbtxt @@ -1,6 +1,14 @@ name: "sample_multiple" backend: "python" max_batch_size: 2 + +parameters [ + { + key: "add", + value: { string_value: "0" } + } +] + input [ { name: "model_in0" diff --git a/tests/test_model_call.py b/tests/test_model_call.py index 314b9d1..99b8f30 100644 --- a/tests/test_model_call.py +++ b/tests/test_model_call.py @@ -16,28 +16,36 @@ def protocol_and_port(request): return request.param -def test_swithcing(protocol_and_port): - protocol, port = protocol_and_port - print(f"Testing {protocol}") +def get_client(protocol, port): + print(f"Testing {protocol}", flush=True) + return InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol) + - client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol) +def test_swithcing(protocol_and_port): + client = get_client(*protocol_and_port) sample = np.random.rand(1, 100).astype(np.float32) result = client(sample) - print(f"Result: {np.isclose(result, sample).all()}") + assert {np.isclose(result, sample).all()} sample_batched = np.random.rand(100, 100).astype(np.float32) client(sample_batched, model_name="sample_autobatching") - print(f"Result: {np.isclose(result, sample).all()}") + assert {np.isclose(result, sample).all()} def test_with_input_name(protocol_and_port): - protocol, port = protocol_and_port - print(f"Testing {protocol}") - - client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol) + client = get_client(*protocol_and_port) sample = np.random.rand(100, 100).astype(np.float32) result = client({client.default_model_spec.model_input[0].name: sample}) + assert {np.isclose(result, sample).all()} + + +def test_with_parameters(protocol_and_port): + client = 
get_client(*protocol_and_port) + + sample = np.random.rand(1, 100).astype(np.float32) + ADD_VALUE = 1 + result = client({client.default_model_spec.model_input[0].name: sample}, parameters={"add": f"{ADD_VALUE}"}) - print(f"Result: {np.isclose(result, sample).all()}") + assert {np.isclose(result[0], sample[0] + ADD_VALUE).all()} diff --git a/tritony/tools.py b/tritony/tools.py index 135a358..1581c58 100644 --- a/tritony/tools.py +++ b/tritony/tools.py @@ -72,6 +72,7 @@ async def send_request_async( done_event, triton_client: Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient], model_spec: TritonModelSpec, + parameters: dict | None = None, ): ret = [] while True: @@ -86,7 +87,7 @@ async def send_request_async( try: a_pred = await request_async( inference_client.flag.protocol, - inference_client.build_triton_input(batch_data, model_spec), + inference_client.build_triton_input(batch_data, model_spec, parameters=parameters), triton_client, timeout=inference_client.client_timeout, compression=inference_client.flag.compression_algorithm, @@ -232,6 +233,7 @@ def _get_request_id(self): def __call__( self, sequences_or_dict: Union[List[Any], Dict[str, List[Any]]], + parameters: dict | None = None, model_name: str | None = None, model_version: str | None = None, ): @@ -254,9 +256,14 @@ def __call__( or (model_input.optional is True and model_input.name in sequences_or_dict) # check optional ] - return self._call_async(sequences_list, model_spec=model_spec) + return self._call_async(sequences_list, model_spec=model_spec, parameters=parameters) - def build_triton_input(self, _input_list: List[np.array], model_spec: TritonModelSpec): + def build_triton_input( + self, + _input_list: List[np.array], + model_spec: TritonModelSpec, + parameters: dict | None = None, + ): if self.flag.protocol is TritonProtocol.grpc: client = grpcclient else: @@ -278,19 +285,30 @@ def build_triton_input(self, _input_list: List[np.array], model_spec: TritonMode request_id=str(request_id), model_version=model_spec.model_version, outputs=infer_requested_output, + parameters=parameters, ) return request_input - def _call_async(self, data: List[np.ndarray], model_spec: TritonModelSpec) -> Optional[np.ndarray]: - async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec)) + def _call_async( + self, + data: List[np.ndarray], + model_spec: TritonModelSpec, + parameters: dict | None = None, + ) -> Optional[np.ndarray]: + async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec, parameters=parameters)) if isinstance(async_result, Exception): raise async_result return async_result - async def _call_async_item(self, data: List[np.ndarray], model_spec: TritonModelSpec): + async def _call_async_item( + self, + data: List[np.ndarray], + model_spec: TritonModelSpec, + parameters: dict | None = None, + ): current_grpc_async_tasks = [] try: @@ -301,7 +319,9 @@ async def _call_async_item(self, data: List[np.ndarray], model_spec: TritonModel current_grpc_async_tasks.append(generator) predict_tasks = [ - asyncio.create_task(send_request_async(self, data_queue, done_event, self.triton_client, model_spec)) + asyncio.create_task( + send_request_async(self, data_queue, done_event, self.triton_client, model_spec, parameters) + ) for idx in range(ASYNC_TASKS) ] current_grpc_async_tasks.extend(predict_tasks) diff --git a/tritony/version.py b/tritony/version.py index b2f0155..72eb129 100644 --- a/tritony/version.py +++ b/tritony/version.py @@ -1 +1 @@ -__version__ = "0.0.11" 
+__version__ = "0.0.12rc0" From d82974e48c7421f902b0babc686e152b03c578b2 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 31 Aug 2023 13:25:56 +0900 Subject: [PATCH 4/5] Add sample_optional, and fix pytest for parameters, optional --- .github/workflows/pre-commit_pytest.yml | 2 +- README.md | 1 + bin/run_triton_tritony_sample.sh | 18 +++++++ model_repository/sample/1/model.py | 4 +- model_repository/sample/config.pbtxt | 7 --- model_repository/sample_optional/1/model.py | 37 +++++++++++++ model_repository/sample_optional/config.pbtxt | 54 +++++++++++++++++++ tests/test_model_call.py | 43 +++++++++++---- tritony/tools.py | 6 ++- 9 files changed, 151 insertions(+), 21 deletions(-) create mode 100755 bin/run_triton_tritony_sample.sh create mode 100644 model_repository/sample_optional/1/model.py create mode 100644 model_repository/sample_optional/config.pbtxt diff --git a/.github/workflows/pre-commit_pytest.yml b/.github/workflows/pre-commit_pytest.yml index ad5c72a..6ae36e8 100644 --- a/.github/workflows/pre-commit_pytest.yml +++ b/.github/workflows/pre-commit_pytest.yml @@ -37,7 +37,7 @@ jobs: runs-on: ubuntu-latest needs: pre-commit container: - image: nvcr.io/nvidia/tritonserver:23.03-pyt-python-py3 + image: nvcr.io/nvidia/tritonserver:23.08-pyt-python-py3 options: --shm-size=1g steps: diff --git a/README.md b/README.md index 9242aff..51a41ff 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,7 @@ if __name__ == "__main__": ## Release Notes +- 23.08.30 Support `optional` with model input, `parameters` on config.pbtxt - 23.06.16 Support tritonclient>=2.34.0 - Loosely modified the requirements related to tritonclient diff --git a/bin/run_triton_tritony_sample.sh b/bin/run_triton_tritony_sample.sh new file mode 100755 index 0000000..1903b46 --- /dev/null +++ b/bin/run_triton_tritony_sample.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +HERE=$(dirname "$(readlink -f $0)") +PARENT_DIR=$(dirname "$HERE") + +docker run -it --rm --name triton_tritony \ + -p8100:8000 \ + -p8101:8001 \ + -p8102:8002 \ + -v "${PARENT_DIR}"/model_repository:/models:ro \ + -e OMP_NUM_THREADS=2 \ + -e OPENBLAS_NUM_THREADS=2 \ + --shm-size=1g \ + nvcr.io/nvidia/tritonserver:23.08-pyt-python-py3 \ + tritonserver --model-repository=/models \ + --exit-timeout-secs 15 \ + --min-supported-compute-capability 7.0 \ + --log-verbose 0 # 0-nothing, 1-info, 2-debug, 3-trace diff --git a/model_repository/sample/1/model.py b/model_repository/sample/1/model.py index 1618bb2..60af9b5 100644 --- a/model_repository/sample/1/model.py +++ b/model_repository/sample/1/model.py @@ -13,13 +13,11 @@ def initialize(self, args): pb_utils.triton_string_to_numpy(output_config["data_type"]) for output_config in output_configs ] - parameters = self.model_config["parameters"] - def execute(self, requests): responses = [None for _ in requests] for idx, request in enumerate(requests): current_add_value = int(json.loads(request.parameters()).get("add", 0)) - in_tensor = [item.as_numpy() + current_add_value for item in request.inputs()] + in_tensor = [item.as_numpy() + current_add_value for item in request.inputs() if item.name() == "model_in"] out_tensor = [ pb_utils.Tensor(output_name, x.astype(output_dtype)) for x, output_name, output_dtype in zip(in_tensor, self.output_name_list, self.output_dtype_list) diff --git a/model_repository/sample/config.pbtxt b/model_repository/sample/config.pbtxt index 60b403d..6c6fec0 100644 --- a/model_repository/sample/config.pbtxt +++ b/model_repository/sample/config.pbtxt @@ -2,13 +2,6 @@ name: "sample" backend: "python" 
max_batch_size: 0 -parameters [ - { - key: "add", - value: { string_value: "0" } - } -] - input [ { name: "model_in" diff --git a/model_repository/sample_optional/1/model.py b/model_repository/sample_optional/1/model.py new file mode 100644 index 0000000..94799d1 --- /dev/null +++ b/model_repository/sample_optional/1/model.py @@ -0,0 +1,37 @@ +import json + +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def initialize(self, args): + self.model_config = model_config = json.loads(args["model_config"]) + output_configs = model_config["output"] + + self.output_name_list = [output_config["name"] for output_config in output_configs] + self.output_dtype_list = [ + pb_utils.triton_string_to_numpy(output_config["data_type"]) for output_config in output_configs + ] + + def execute(self, requests): + responses = [None for _ in requests] + for idx, request in enumerate(requests): + current_add_value = int(json.loads(request.parameters()).get("add", 0)) + optional_in_tensor = pb_utils.get_input_tensor_by_name(request, "optional_model_sub") + if optional_in_tensor: + optional_in_tensor = optional_in_tensor.as_numpy() + else: + optional_in_tensor = 0 + + in_tensor = [ + item.as_numpy() + current_add_value - optional_in_tensor + for item in request.inputs() + if item.name() == "model_in" + ] + out_tensor = [ + pb_utils.Tensor(output_name, x.astype(output_dtype)) + for x, output_name, output_dtype in zip(in_tensor, self.output_name_list, self.output_dtype_list) + ] + inference_response = pb_utils.InferenceResponse(output_tensors=out_tensor) + responses[idx] = inference_response + return responses diff --git a/model_repository/sample_optional/config.pbtxt b/model_repository/sample_optional/config.pbtxt new file mode 100644 index 0000000..e409997 --- /dev/null +++ b/model_repository/sample_optional/config.pbtxt @@ -0,0 +1,54 @@ +name: "sample_optional" +backend: "python" +max_batch_size: 0 + +parameters [ + { + key: "add", + value: { string_value: "0" } + } +] + +input [ +{ + name: "model_in" + data_type: TYPE_FP32 + dims: [ -1 ] +}, +{ + name: "optional_model_sub" + data_type: TYPE_FP32 + optional: true + dims: [ -1 ] +} +] + +output [ +{ + name: "model_out" + data_type: TYPE_FP32 + dims: [ -1 ] +} +] + +instance_group [{ kind: KIND_CPU, count: 1 }] + +model_warmup { + name: "RandomSampleInput" + batch_size: 1 + inputs [{ + key: "model_in" + value: { + data_type: TYPE_FP32 + dims: [ 10 ] + random_data: true + } + }, { + key: "model_in" + value: { + data_type: TYPE_FP32 + dims: [ 10 ] + zero_data: true + } + }] +} \ No newline at end of file diff --git a/tests/test_model_call.py b/tests/test_model_call.py index 99b8f30..e357a09 100644 --- a/tests/test_model_call.py +++ b/tests/test_model_call.py @@ -5,24 +5,26 @@ from tritony import InferenceClient -MODEL_NAME = os.environ.get("MODEL_NAME", "sample") TRITON_HOST = os.environ.get("TRITON_HOST", "localhost") TRITON_HTTP = os.environ.get("TRITON_HTTP", "8000") TRITON_GRPC = os.environ.get("TRITON_GRPC", "8001") +EPSILON = 1e-8 + + @pytest.fixture(params=[("http", TRITON_HTTP), ("grpc", TRITON_GRPC)]) def protocol_and_port(request): return request.param -def get_client(protocol, port): +def get_client(protocol, port, model_name): print(f"Testing {protocol}", flush=True) - return InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol) + return InferenceClient.create_with(model_name, f"{TRITON_HOST}:{port}", protocol=protocol) def test_swithcing(protocol_and_port): - client = 
get_client(*protocol_and_port) + client = get_client(*protocol_and_port, model_name="sample") sample = np.random.rand(1, 100).astype(np.float32) result = client(sample) @@ -30,22 +32,45 @@ def test_swithcing(protocol_and_port): sample_batched = np.random.rand(100, 100).astype(np.float32) client(sample_batched, model_name="sample_autobatching") - assert {np.isclose(result, sample).all()} + assert np.isclose(result, sample).all() def test_with_input_name(protocol_and_port): - client = get_client(*protocol_and_port) + client = get_client(*protocol_and_port, model_name="sample") sample = np.random.rand(100, 100).astype(np.float32) result = client({client.default_model_spec.model_input[0].name: sample}) - assert {np.isclose(result, sample).all()} + assert np.isclose(result, sample).all() def test_with_parameters(protocol_and_port): - client = get_client(*protocol_and_port) + client = get_client(*protocol_and_port, model_name="sample") sample = np.random.rand(1, 100).astype(np.float32) ADD_VALUE = 1 result = client({client.default_model_spec.model_input[0].name: sample}, parameters={"add": f"{ADD_VALUE}"}) - assert {np.isclose(result[0], sample[0] + ADD_VALUE).all()} + assert np.isclose(result[0], sample[0] + ADD_VALUE).all() + + +def test_with_optional(protocol_and_port): + client = get_client(*protocol_and_port, model_name="sample_optional") + + sample = np.random.rand(1, 100).astype(np.float32) + + result = client({client.default_model_spec.model_input[0].name: sample}) + assert np.isclose(result[0], sample[0], rtol=EPSILON).all() + + OPTIONAL_SUB_VALUE = np.zeros_like(sample) + 3 + result = client( + { + client.default_model_spec.model_input[0].name: sample, + "optional_model_sub": OPTIONAL_SUB_VALUE, + } + ) + assert np.isclose(result[0], sample[0] - OPTIONAL_SUB_VALUE, rtol=EPSILON).all() + + +if __name__ == "__main__": + test_with_parameters(("grpc", "8101")) + test_with_optional(("grpc", "8101")) diff --git a/tritony/tools.py b/tritony/tools.py index 1581c58..e09b0ea 100644 --- a/tritony/tools.py +++ b/tritony/tools.py @@ -115,7 +115,11 @@ async def request_async(protocol: TritonProtocol, model_input: Dict, triton_clie loop = asyncio.get_running_loop() if "parameters" in grpc_get_inference_request.__code__.co_varnames: - model_input["parameters"] = None + # check tritonclient[all]>=2.34.0, NGC 23.04 + model_input["parameters"] = model_input.get("parameters", None) + else: + logger.warning("tritonclient[all]<2.34.0, NGC 21.04") + model_input.pop("parameters") request = grpc_get_inference_request( **model_input, priority=0, From bfea3e9c98e31c8ce6d868c9524a9f0f373dc8c0 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 31 Aug 2023 13:51:08 +0900 Subject: [PATCH 5/5] Fix README.md with 23.08 --- README.md | 7 ++----- model_repository/sample_optional/config.pbtxt | 6 ++++++ tests/test_connect.py | 6 +++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 51a41ff..0d81c23 100644 --- a/README.md +++ b/README.md @@ -80,14 +80,11 @@ if __name__ == "__main__": ### With Triton ```bash -docker run --rm \ - -v ${PWD}:/models \ - nvcr.io/nvidia/tritonserver:22.01-pyt-python-py3 \ - tritonserver --model-repo=/models +./bin/run_triton_tritony_sample.sh ``` ```bash -pytest -m -s tests/test_tritony.py +pytest -s --cov-report term-missing --cov=tritony tests/ ``` ### Example with image_client.py diff --git a/model_repository/sample_optional/config.pbtxt b/model_repository/sample_optional/config.pbtxt index e409997..d5a5aea 100644 --- 
a/model_repository/sample_optional/config.pbtxt +++ b/model_repository/sample_optional/config.pbtxt @@ -20,6 +20,12 @@ input [ data_type: TYPE_FP32 optional: true dims: [ -1 ] +}, +{ + name: "optional_model_string" + data_type: TYPE_STRING + optional: true + dims: [ -1 ] } ] diff --git a/tests/test_connect.py b/tests/test_connect.py index 886fd3f..b282b25 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -25,10 +25,10 @@ def test_basics(protocol_and_port): sample = np.random.rand(1, 100).astype(np.float32) result = client(sample) - print(f"Result: {np.isclose(result, sample).all()}") + assert np.isclose(result, sample).all() result = client({"model_in": sample}) - print(f"Dict Result: {np.isclose(result, sample).all()}") + assert np.isclose(result, sample).all() def test_batching(protocol_and_port): @@ -40,7 +40,7 @@ def test_batching(protocol_and_port): sample = np.random.rand(100, 100).astype(np.float32) # client automatically makes sub batches with (50, 2, 100) result = client(sample) - print(f"Result: {np.isclose(result, sample).all()}") + assert np.isclose(result, sample).all() def test_exception(protocol_and_port):
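
A minimal usage sketch of the features added in this series (per-request `parameters` and `optional` model inputs), following tests/test_model_call.py; it assumes the sample model repository is served via bin/run_triton_tritony_sample.sh, so the gRPC endpoint localhost:8101 and the model name "sample_optional" come from that setup rather than from the patches themselves.

import numpy as np

from tritony import InferenceClient

# Assumes a Triton server started by bin/run_triton_tritony_sample.sh,
# exposing gRPC on localhost:8101 with the sample_optional model loaded.
client = InferenceClient.create_with("sample_optional", "localhost:8101", protocol="grpc")
in_name = client.default_model_spec.model_input[0].name  # "model_in" per config.pbtxt

sample = np.random.rand(1, 100).astype(np.float32)

# Required input only; the optional "optional_model_sub" input is simply omitted.
echoed = client({in_name: sample})

# Per-request parameters are forwarded to the backend and read via request.parameters().
added = client({in_name: sample}, parameters={"add": "1"})

# Supplying the optional input makes sample_optional subtract it element-wise.
sub = np.full_like(sample, 3.0)
subtracted = client({in_name: sample, "optional_model_sub": sub})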