Skip to content

Commit

Permalink
fix coompatible pproblem for python3.10 and gradio (#1249)
Browse files Browse the repository at this point in the history
  • Loading branch information
rainyfly committed May 4, 2023
1 parent 6ae0838 commit e420b8c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 17 deletions.
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ packaging
x2paddle >= 1.4.0
paddle2onnx >= 1.0.5
rarfile
gradio
gradio == 3.11.0
tritonclient[all]
attrdict
psutil
onnx >= 1.6.0
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,6 @@
import numpy as np
import requests
import tritonclient.http as httpclient
from attrdict import AttrDict


def convert_http_metadata_config(metadata):
metadata = AttrDict(metadata)

return metadata


def prepare_request(inputs_meta, inputs_data, outputs_meta):
Expand Down Expand Up @@ -58,7 +51,7 @@ def prepare_request(inputs_meta, inputs_data, outputs_meta):
inputs.append(infer_input)
outputs = []
for output_dict in outputs_meta:
infer_output = httpclient.InferRequestedOutput(output_dict.name)
infer_output = httpclient.InferRequestedOutput(output_dict['name'])
outputs.append(infer_output)
return inputs, outputs

Expand Down Expand Up @@ -321,8 +314,8 @@ def infer(self, server_url, model_name, model_version, inputs):

results = {}
for output in output_metadata:
result = response.as_numpy(output.name) # datatype: numpy
if output.datatype == 'BYTES': # datatype: bytes
result = response.as_numpy(output['name']) # datatype: numpy
if output['datatype'] == 'BYTES': # datatype: bytes
try:
value = result
if len(result.shape) == 1:
Expand All @@ -336,7 +329,7 @@ def infer(self, server_url, model_name, model_version, inputs):
pass
else:
result = result[0]
results[output.name] = result
results[output['name']] = result
return results

def raw_infer(self, server_url, model_name, model_version, raw_input):
Expand All @@ -353,8 +346,6 @@ def get_model_meta(self, server_url, model_name, model_version):
except Exception as e:
raise RuntimeError("Failed to retrieve the metadata: " + str(e))

model_metadata = convert_http_metadata_config(model_metadata)

input_metadata = model_metadata.inputs
output_metadata = model_metadata.outputs
input_metadata = model_metadata['inputs']
output_metadata = model_metadata['outputs']
return input_metadata, output_metadata

0 comments on commit e420b8c

Please sign in to comment.