diff --git a/.github/workflows/manual_update_version.yml b/.github/workflows/manual_update_version.yml index 344079f..556d60a 100644 --- a/.github/workflows/manual_update_version.yml +++ b/.github/workflows/manual_update_version.yml @@ -27,6 +27,11 @@ jobs: # sed -i.bak -E "s|[^<]+|$VERSION|" wrapyfi_extensions/wrapyfi_ros2_interfaces/package.xml # sed -i.bak -E "s|[^<]+|$VERSION|" wrapyfi_extensions/wrapyfi_ros_interfaces/package.xml + - name: Refactor with Code Black + run: | + python3 -m pip install black + black . + - name: Generate Documentation run: | cd docs diff --git a/README.md b/README.md index cbbe323..95e5ab9 100755 --- a/README.md +++ b/README.md @@ -14,12 +14,17 @@ [![arXiv](https://custom-icon-badges.demolab.com/badge/arXiv:2302.09648-lightyellow.svg?logo=arxiv-logomark-small)](https://arxiv.org/abs/2302.09648 "arXiv link") [![doi](https://custom-icon-badges.demolab.com/badge/10.1145/3610977.3637471-lightyellow.svg?logo=doi_logo)](https://doi.org/10.1145/3610977.3637471 "doi link") + + [![License](https://custom-icon-badges.demolab.com/github/license/denvercoder1/custom-icon-badges?logo=law&logoColor=white)](https://github.com/fabawi/wrapyfi/blob/main/LICENSE "license MIT") [![FOSSA status](https://app.fossa.com/api/projects/git%2Bgithub.com%2Fmodular-ml%2Fwrapyfi.svg?type=shield&issueType=license)](https://app.fossa.com/projects/git%2Bgithub.com%2Fmodular-ml%2Fwrapyfi?ref=badge_shield&issueType=license) [![Documentation status](https://readthedocs.org/projects/wrapyfi/badge/?version=latest)](https://wrapyfi.readthedocs.io/en/latest/?badge=latest) -[![codecov](https://codecov.io/github/modular-ml/wrapyfi/graph/badge.svg?token=5SD1A6ENKE)](https://codecov.io/github/modular-ml/wrapyfi) +[![codecov](https://codecov.io/github/modular-ml/wrapyfi/graph/badge.svg?token=5SD1A6ENKE)](https://codecov.io/github/modular-ml/wrapyfi) [![PyPI version](https://badge.fury.io/py/wrapyfi.svg)](https://badge.fury.io/py/wrapyfi) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://black.readthedocs.io/ "code style link") +[![PyPI - Implementation](https://img.shields.io/pypi/implementation/wrapyfi)](https://pypi.org/project/wrapyfi/ "implementation") +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/wrapyfi)](https://pypi.org/project/wrapyfi/ "python version") [![PyPI total downloads](https://img.shields.io/pepy/dt/wrapyfi)](https://www.pepy.tech/projects/wrapyfi) [![Docker Hub Pulls](https://img.shields.io/docker/pulls/modularml/wrapyfi.svg)](https://hub.docker.com/repository/docker/modularml/wrapyfi) diff --git a/docs/_extensions/link_modifier.py b/docs/_extensions/link_modifier.py index 122686d..7e7fc97 100644 --- a/docs/_extensions/link_modifier.py +++ b/docs/_extensions/link_modifier.py @@ -3,38 +3,36 @@ import re REPLACEMENTS = { - 'https://github.com/fabawi/wrapyfi/tree/main/wrapyfi_extensions/yarp/README.md': - 'yarp_install_lnk.html', - 'https://github.com/modular-ml/wrapyfi_ros2_interfaces/blob/master/README.md': - 'ros2_interfaces_lnk.html', - 'https://github.com/modular-ml/wrapyfi_ros_interfaces/blob/master/README.md': - 'ros_interfaces_lnk.html', - 'https://github.com/fabawi/wrapyfi/tree/main/dockerfiles/README.md': - 'wrapyfi_docker_lnk.html', + "https://github.com/fabawi/wrapyfi/tree/main/wrapyfi_extensions/yarp/README.md": "yarp_install_lnk.html", + "https://github.com/modular-ml/wrapyfi_ros2_interfaces/blob/master/README.md": "ros2_interfaces_lnk.html", + "https://github.com/modular-ml/wrapyfi_ros_interfaces/blob/master/README.md": "ros_interfaces_lnk.html", + "https://github.com/fabawi/wrapyfi/tree/main/dockerfiles/README.md": "wrapyfi_docker_lnk.html", } + class LinkModifier(SphinxTransform): default_priority = 999 def apply(self): for node in self.document.traverse(nodes.reference): - uri = node.get('refuri', '') + uri = node.get("refuri", "") for link in REPLACEMENTS.keys(): if link in uri: # Extract rank value - match = re.search(r'\?rank=(-?\d+)', uri) + match = re.search(r"\?rank=(-?\d+)", uri) if match: rank = int(match.group(1)) if rank < 0: - prefix = '../' * -rank + prefix = "../" * -rank new_uri = prefix + REPLACEMENTS[link] else: new_uri = REPLACEMENTS[link] else: # Default to rank -2 if no rank parameter - new_uri = '../../' + REPLACEMENTS[link] + new_uri = "../../" + REPLACEMENTS[link] + + node["refuri"] = uri.replace(link, new_uri) - node['refuri'] = uri.replace(link, new_uri) def setup(app): app.add_transform(LinkModifier) diff --git a/docs/_extensions/math_block_converter.py b/docs/_extensions/math_block_converter.py index dc4165a..c658afb 100644 --- a/docs/_extensions/math_block_converter.py +++ b/docs/_extensions/math_block_converter.py @@ -1,9 +1,11 @@ from sphinx.application import Sphinx import re + def convert_math_blocks(app: Sphinx, docname: str, source: list): if source: - source[0] = re.sub(r'```math', '```{math}', source[0]) + source[0] = re.sub(r"```math", "```{math}", source[0]) + def setup(app: Sphinx): app.connect("source-read", convert_math_blocks) diff --git a/docs/conf.py b/docs/conf.py index bfe9a9e..ab8653c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -6,40 +6,41 @@ import ast import json + def get_project_info_from_setup(): curr_dir = os.path.dirname(__file__) - setup_path = os.path.join(curr_dir, '..', 'setup.py') - with open(setup_path, 'r') as f: + setup_path = os.path.join(curr_dir, "..", "setup.py") + with open(setup_path, "r") as f: content = f.read() - + name_match = re.search(r"name\s*=\s*['\"]([^'\"]*)['\"]", content) version_match = re.search(r"version\s*=\s*['\"]([^'\"]*)['\"]", content) url_match = re.search(r"url\s*=\s*['\"]([^'\"]*)['\"]", content) author_match = re.search(r"author\s*=\s*['\"]([^'\"]*)['\"]", content) - + if not name_match or not version_match or not url_match: raise RuntimeError("Unable to find name, version, url, or author string.") - + return { - 'name': name_match.group(1), - 'version': version_match.group(1), - 'url': url_match.group(1), - 'author': author_match.group(1) + "name": name_match.group(1), + "version": version_match.group(1), + "url": url_match.group(1), + "author": author_match.group(1), } def get_imported_modules(package_name): package = importlib.import_module(package_name) imported_modules = [] - for _, module_name, _ in pkgutil.walk_packages(path=package.__path__, - prefix=package.__name__ + '.', - onerror=lambda x: None): + for _, module_name, _ in pkgutil.walk_packages( + path=package.__path__, prefix=package.__name__ + ".", onerror=lambda x: None + ): imported_modules.append(module_name) return imported_modules def get_all_imports_in_file(file_path): - with open(file_path, 'r', encoding='utf-8') as file: + with open(file_path, "r", encoding="utf-8") as file: tree = ast.parse(file.read(), filename=file_path) imports = set() @@ -57,37 +58,45 @@ def get_all_imports_in_package(package_path): all_imports = set() for root, dirs, files in os.walk(package_path): for file in files: - if file.endswith('.py'): + if file.endswith(".py"): file_path = os.path.join(root, file) all_imports.update(get_all_imports_in_file(file_path)) return all_imports + def setup(app): - app.add_css_file('wide_theme.css') + app.add_css_file("wide_theme.css") autodoc_default_options = { - 'members': True, - 'member-order': 'bysource', - 'special-members': '__init__', - 'undoc-members': True, - 'exclude-members': '__weakref__' + "members": True, + "member-order": "bysource", + "special-members": "__init__", + "undoc-members": True, + "exclude-members": "__weakref__", } -main_doc = 'index' -html_theme = 'sphinx_rtd_theme' -html_static_path = ['_static'] -html_css_files = ['wide_theme.css'] - -extensions = ['sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'myst_parser', 'sphinx.ext.mathjax', - 'math_block_converter', 'link_modifier'] -source_suffix = ['.rst', '.md'] +main_doc = "index" +html_theme = "sphinx_rtd_theme" +html_static_path = ["_static"] +html_css_files = ["wide_theme.css"] + +extensions = [ + "sphinx.ext.todo", + "sphinx.ext.viewcode", + "sphinx.ext.autodoc", + "myst_parser", + "sphinx.ext.mathjax", + "math_block_converter", + "link_modifier", +] +source_suffix = [".rst", ".md"] exclude_patterns = ["_build"] myst_enable_extensions = ["dollarmath", "amsmath"] # mock all libraries except for the ones that are installed -with open('exclude_packages.json', 'r') as f: +with open("exclude_packages.json", "r") as f: all_imported_modules_pre = set(x for x in json.load(f) if x is not None) print(all_imported_modules_pre) # all_imported_modules = get_all_imports_in_package("wrapyfi") @@ -100,27 +109,31 @@ def setup(app): # extract project info project_info = get_project_info_from_setup() -project = project_info['name'] -release = project_info['version'] -version = '.'.join(release.split('.')[:2]) -url = project_info['url'] -author = project_info['author'] +project = project_info["name"] +release = project_info["version"] +version = ".".join(release.split(".")[:2]) +url = project_info["url"] +author = project_info["author"] # modify the latex cover page for pdf generation latex_elements = { - 'preamble': r''' + "preamble": r""" \usepackage{titling} \pretitle{% \begin{center} \vspace{\droptitle} \includegraphics[width=60mm]{../assets/wrapyfi.png}\\[\bigskipamount] - \Large{\textbf{''' + project + '''}}\\ - \normalsize{v''' + release + '''} + \Large{\textbf{""" + + project + + """}}\\ + \normalsize{v""" + + release + + """} } \posttitle{\end{center}} -''' +""" } -sys.path.insert(0, os.path.abspath('../')) -sys.path.append(os.path.abspath('./_extensions')) -sys.path.append(os.path.abspath('./mock_imports')) +sys.path.insert(0, os.path.abspath("../")) +sys.path.append(os.path.abspath("./_extensions")) +sys.path.append(os.path.abspath("./mock_imports")) diff --git a/docs/mock_imports/geometry_msgs/msg/__init__.py b/docs/mock_imports/geometry_msgs/msg/__init__.py index c032b16..11e57eb 100644 --- a/docs/mock_imports/geometry_msgs/msg/__init__.py +++ b/docs/mock_imports/geometry_msgs/msg/__init__.py @@ -5,4 +5,4 @@ def __init__(self, *args, **kwargs): class Quaternion(object): def __init__(self, *args, **kwargs): - pass \ No newline at end of file + pass diff --git a/docs/mock_imports/rclpy/__init__.py b/docs/mock_imports/rclpy/__init__.py index 852e4aa..cc0f15f 100644 --- a/docs/mock_imports/rclpy/__init__.py +++ b/docs/mock_imports/rclpy/__init__.py @@ -1,3 +1,3 @@ class Parameter(object): def __init__(self, *args, **kwargs): - pass \ No newline at end of file + pass diff --git a/docs/mock_imports/std_msgs/msg/__init__.py b/docs/mock_imports/std_msgs/msg/__init__.py index 2b53c11..5862ec4 100644 --- a/docs/mock_imports/std_msgs/msg/__init__.py +++ b/docs/mock_imports/std_msgs/msg/__init__.py @@ -1,3 +1,3 @@ class String(object): def __init__(self, *args, **kwargs): - pass \ No newline at end of file + pass diff --git a/examples/applications/affective_signaling_multirobot.py b/examples/applications/affective_signaling_multirobot.py index 36cbdcc..d900623 100644 --- a/examples/applications/affective_signaling_multirobot.py +++ b/examples/applications/affective_signaling_multirobot.py @@ -19,93 +19,210 @@ class ExperimentController(MiddlewareCommunicator): def __init__(self, **kwargs): super(ExperimentController, self).__init__() - @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController", - "/control_interface/facial_expressions_esr9", should_wait=False) + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ExperimentController", + "/control_interface/facial_expressions_esr9", + should_wait=False, + ) def listen_facial_expressions_esr9(self, _mware=DEFAULT_COMMUNICATOR): - return None, + return (None,) - @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController", - "/control_interface/facial_expressions_icub", should_wait=False) + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ExperimentController", + "/control_interface/facial_expressions_icub", + should_wait=False, + ) def publish_facial_expressions_icub(self, obj, _mware="yarp"): - return obj, + return (obj,) - @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController", - "/control_interface/facial_expressions_pepper", - carrier="tcp", should_wait=False) + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ExperimentController", + "/control_interface/facial_expressions_pepper", + carrier="tcp", + should_wait=False, + ) def publish_facial_expressions_pepper(self, obj, _mware="ros"): - return obj, - - @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", "$_topic", - width="$_width", height="$_height", rgb=True, fp=False, should_wait=False, - jpg=True) - def listen_image_webcam(self, _mware=DEFAULT_COMMUNICATOR, _width=WEBCAM_WIDTH, _height=WEBCAM_HEIGHT, - _topic="/control_interface/image_webcam"): # dynamic topic according to corresponding API - return None, - - @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", "$_topic", - width="$_width", height="$_height", rgb=True, fp=False, should_wait=False) - def listen_image_pepper_cam(self, _mware="ros", _width=PEPPER_CAM_WIDTH, _height=PEPPER_CAM_HEIGHT, - _topic="/pepper/camera/front/camera/image_raw"): - return None, - - @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", "$_topic", - width="$_width", height="$_height", rgb=True, fp=False, should_wait=False) - def listen_image_icub_cam(self, _mware="yarp", _width=ICUB_CAM_WIDTH, _height=ICUB_CAM_HEIGHT, - _topic="/icub/cam/right"): - return None, - - @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", - "/control_interface/image_webcam", - width="$_width", height="$_height", rgb=True, fp=False, should_wait=False) - def forward_image_webcam(self, img, _mware=DEFAULT_COMMUNICATOR, _width=WEBCAM_WIDTH, - _height=WEBCAM_HEIGHT): - return img, - - @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", - "/control_interface/image_pepper_cam", - width="$_width", height="$_height", rgb=True, fp=False, jpg=True, - should_wait=False) - def forward_image_pepper_cam(self, img, _mware=DEFAULT_COMMUNICATOR, _width=PEPPER_CAM_WIDTH, - _height=PEPPER_CAM_HEIGHT): - return img, - - @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", - "/control_interface/image_icub_cam", - width="$_width", height="$_height", rgb=True, fp=False, jpg=True, - should_wait=False) - def forward_image_icub_cam(self, img, _mware=DEFAULT_COMMUNICATOR, _width=ICUB_CAM_WIDTH, - _height=ICUB_CAM_HEIGHT): + return (obj,) + + @MiddlewareCommunicator.register( + "Image", + "$_mware", + "ExperimentController", + "$_topic", + width="$_width", + height="$_height", + rgb=True, + fp=False, + should_wait=False, + jpg=True, + ) + def listen_image_webcam( + self, + _mware=DEFAULT_COMMUNICATOR, + _width=WEBCAM_WIDTH, + _height=WEBCAM_HEIGHT, + _topic="/control_interface/image_webcam", + ): # dynamic topic according to corresponding API + return (None,) + + @MiddlewareCommunicator.register( + "Image", + "$_mware", + "ExperimentController", + "$_topic", + width="$_width", + height="$_height", + rgb=True, + fp=False, + should_wait=False, + ) + def listen_image_pepper_cam( + self, + _mware="ros", + _width=PEPPER_CAM_WIDTH, + _height=PEPPER_CAM_HEIGHT, + _topic="/pepper/camera/front/camera/image_raw", + ): + return (None,) + + @MiddlewareCommunicator.register( + "Image", + "$_mware", + "ExperimentController", + "$_topic", + width="$_width", + height="$_height", + rgb=True, + fp=False, + should_wait=False, + ) + def listen_image_icub_cam( + self, + _mware="yarp", + _width=ICUB_CAM_WIDTH, + _height=ICUB_CAM_HEIGHT, + _topic="/icub/cam/right", + ): + return (None,) + + @MiddlewareCommunicator.register( + "Image", + "$_mware", + "ExperimentController", + "/control_interface/image_webcam", + width="$_width", + height="$_height", + rgb=True, + fp=False, + should_wait=False, + ) + def forward_image_webcam( + self, + img, + _mware=DEFAULT_COMMUNICATOR, + _width=WEBCAM_WIDTH, + _height=WEBCAM_HEIGHT, + ): + return (img,) + + @MiddlewareCommunicator.register( + "Image", + "$_mware", + "ExperimentController", + "/control_interface/image_pepper_cam", + width="$_width", + height="$_height", + rgb=True, + fp=False, + jpg=True, + should_wait=False, + ) + def forward_image_pepper_cam( + self, + img, + _mware=DEFAULT_COMMUNICATOR, + _width=PEPPER_CAM_WIDTH, + _height=PEPPER_CAM_HEIGHT, + ): + return (img,) + + @MiddlewareCommunicator.register( + "Image", + "$_mware", + "ExperimentController", + "/control_interface/image_icub_cam", + width="$_width", + height="$_height", + rgb=True, + fp=False, + jpg=True, + should_wait=False, + ) + def forward_image_icub_cam( + self, + img, + _mware=DEFAULT_COMMUNICATOR, + _width=ICUB_CAM_WIDTH, + _height=ICUB_CAM_HEIGHT, + ): # convert to bgr if img is not None: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - return img, + return (img,) - @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", - "/control_interface/image_esr9", - width="$_width", height="$_height", rgb=True, fp=False, should_wait=False, jpg=True) - def publish_esr9_cam(self, img, _mware=DEFAULT_COMMUNICATOR, _width=ESR_CAM_WIDTH, _height=ESR_CAM_HEIGHT): + @MiddlewareCommunicator.register( + "Image", + "$_mware", + "ExperimentController", + "/control_interface/image_esr9", + width="$_width", + height="$_height", + rgb=True, + fp=False, + should_wait=False, + jpg=True, + ) + def publish_esr9_cam( + self, + img, + _mware=DEFAULT_COMMUNICATOR, + _width=ESR_CAM_WIDTH, + _height=ESR_CAM_HEIGHT, + ): # modify size of image if img is not None: img = cv2.resize(img, (self.ESR_CAM_WIDTH, self.ESR_CAM_HEIGHT)) - return img, + return (img,) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--wrapyfi_cfg", - help="File to load Wrapyfi configs for running instance. " - "Choose one of the configs available in " - "./wrapyfi_configs/affective_signaling_multirobot " - "for each running instance. All configs with a prefix of COMP must " - "have corresponding instances running. OPT prefixed configs execute " - "scripts on a machine connected to either robot (Pepper/iCub) or both, " - "and at least one must run in addition to COMP.", - type=str) - parser.add_argument("--cam_source", - help="The camera input source being either from a " - "webcam, Pepper, or iCub. Note that this must be similar for all running " - "instances even when the two robots are set to display the expressions.", - type=str, default="webcam", choices=["webcam", "pepper", "icub"]) + parser.add_argument( + "--wrapyfi_cfg", + help="File to load Wrapyfi configs for running instance. " + "Choose one of the configs available in " + "./wrapyfi_configs/affective_signaling_multirobot " + "for each running instance. All configs with a prefix of COMP must " + "have corresponding instances running. OPT prefixed configs execute " + "scripts on a machine connected to either robot (Pepper/iCub) or both, " + "and at least one must run in addition to COMP.", + type=str, + ) + parser.add_argument( + "--cam_source", + help="The camera input source being either from a " + "webcam, Pepper, or iCub. Note that this must be similar for all running " + "instances even when the two robots are set to display the expressions.", + type=str, + default="webcam", + choices=["webcam", "pepper", "icub"], + ) return parser.parse_args() @@ -119,8 +236,8 @@ def parse_args(): image_cam = None while True: if args.cam_source == "webcam": - image_webcam, = ec.listen_image_webcam() - image_webcam_linked, = ec.forward_image_webcam(image_webcam) + (image_webcam,) = ec.listen_image_webcam() + (image_webcam_linked,) = ec.forward_image_webcam(image_webcam) if image_webcam is None: image_webcam = image_webcam_linked image_cam = image_webcam @@ -128,8 +245,8 @@ def parse_args(): cv2.imshow("Webcam image", image_webcam) cv2.waitKey(1) if args.cam_source == "pepper": - image_pepper_cam, = ec.listen_image_pepper_cam() - image_pepper_cam_linked, = ec.forward_image_pepper_cam(image_pepper_cam) + (image_pepper_cam,) = ec.listen_image_pepper_cam() + (image_pepper_cam_linked,) = ec.forward_image_pepper_cam(image_pepper_cam) if image_pepper_cam is None: image_pepper_cam = image_pepper_cam_linked image_cam = image_pepper_cam @@ -137,8 +254,8 @@ def parse_args(): cv2.imshow("Pepper image", image_pepper_cam) cv2.waitKey(1) if args.cam_source == "icub": - image_icub_cam, = ec.listen_image_icub_cam() - image_icub_cam_linked, = ec.forward_image_icub_cam(image_icub_cam) + (image_icub_cam,) = ec.listen_image_icub_cam() + (image_icub_cam_linked,) = ec.forward_image_icub_cam(image_icub_cam) if image_icub_cam is None: image_icub_cam = image_icub_cam_linked image_cam = image_icub_cam @@ -146,9 +263,8 @@ def parse_args(): cv2.imshow("iCub image", image_icub_cam_linked) cv2.waitKey(1) - image_esr, = ec.publish_esr9_cam(image_cam) - facial_expression, = ec.listen_facial_expressions_esr9() + (image_esr,) = ec.publish_esr9_cam(image_cam) + (facial_expression,) = ec.listen_facial_expressions_esr9() if facial_expression is not None: ec.publish_facial_expressions_icub(facial_expression) ec.publish_facial_expressions_pepper(facial_expression) - diff --git a/examples/applications/gaze_mirroring_multisensor.py b/examples/applications/gaze_mirroring_multisensor.py index f5181b5..30eca32 100644 --- a/examples/applications/gaze_mirroring_multisensor.py +++ b/examples/applications/gaze_mirroring_multisensor.py @@ -17,46 +17,103 @@ class ExperimentController(MiddlewareCommunicator): def __init__(self, **kwargs): super(ExperimentController, self).__init__() - @MiddlewareCommunicator.register("NativeObject", "$_sixdrepnet_mware", "ExperimentController", - "/control_interface/orientation_sixdrepnet", should_wait=False) - @MiddlewareCommunicator.register("NativeObject", "$_waveshareimu_mware", "ExperimentController", - "/control_interface/orientation_waveshareimu", should_wait=False) - def listen_orientation_sixdrepnet_waveshareimu(self, _sixdrepnet_mware="yarp", _waveshareimu_mware="yarp"): - return None, None, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController", - "/control_interface/orientation_icub", should_wait=False) + @MiddlewareCommunicator.register( + "NativeObject", + "$_sixdrepnet_mware", + "ExperimentController", + "/control_interface/orientation_sixdrepnet", + should_wait=False, + ) + @MiddlewareCommunicator.register( + "NativeObject", + "$_waveshareimu_mware", + "ExperimentController", + "/control_interface/orientation_waveshareimu", + should_wait=False, + ) + def listen_orientation_sixdrepnet_waveshareimu( + self, _sixdrepnet_mware="yarp", _waveshareimu_mware="yarp" + ): + return ( + None, + None, + ) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ExperimentController", + "/control_interface/orientation_icub", + should_wait=False, + ) def publish_orientation_icub(self, obj, _mware="yarp"): - return obj, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController", - "/control_interface/gaze_pupil", should_wait=False) + return (obj,) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ExperimentController", + "/control_interface/gaze_pupil", + should_wait=False, + ) def listen_gaze_pupil(self, _mware="zeromq"): - return None, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController", - "/control_interface/gaze_icub", should_wait=False) + return (None,) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ExperimentController", + "/control_interface/gaze_icub", + should_wait=False, + ) def publish_gaze_icub(self, obj, _mware="yarp"): - return obj, - - - @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", - "/control_interface/image_webcam", - width="$_width", height="$_height", rgb=True, fp=False, should_wait=False, - jpg=True) - def listen_image_webcam(self, _mware="ros2", _width=WEBCAM_WIDTH, _height=WEBCAM_HEIGHT): - return None, - - @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", - "/control_interface/image_sixdrepnet", - width="$_width", height="$_height", rgb=True, fp=False, should_wait=False) - def publish_sixdrepnet_cam(self, img, _mware="ros2", _width=SIXDREPNET_CAM_WIDTH, _height=SIXDREPNET_CAM_HEIGHT): + return (obj,) + + @MiddlewareCommunicator.register( + "Image", + "$_mware", + "ExperimentController", + "/control_interface/image_webcam", + width="$_width", + height="$_height", + rgb=True, + fp=False, + should_wait=False, + jpg=True, + ) + def listen_image_webcam( + self, _mware="ros2", _width=WEBCAM_WIDTH, _height=WEBCAM_HEIGHT + ): + return (None,) + + @MiddlewareCommunicator.register( + "Image", + "$_mware", + "ExperimentController", + "/control_interface/image_sixdrepnet", + width="$_width", + height="$_height", + rgb=True, + fp=False, + should_wait=False, + ) + def publish_sixdrepnet_cam( + self, + img, + _mware="ros2", + _width=SIXDREPNET_CAM_WIDTH, + _height=SIXDREPNET_CAM_HEIGHT, + ): # modify size of image if img is not None: - img = cv2.resize(img, (self.SIXDREPNET_CAM_WIDTH, self.SIXDREPNET_CAM_HEIGHT)) - return img, - - def priority_control_sources(self, orientation_sixdrepnet, orientation_imu, control_sources): + img = cv2.resize( + img, (self.SIXDREPNET_CAM_WIDTH, self.SIXDREPNET_CAM_HEIGHT) + ) + return (img,) + + def priority_control_sources( + self, orientation_sixdrepnet, orientation_imu, control_sources + ): if control_sources[0] == "vision": if orientation_sixdrepnet is not None: # print("Orientation 6DRepNet: ", orientation_sixdrepnet) @@ -72,27 +129,38 @@ def priority_control_sources(self, orientation_sixdrepnet, orientation_imu, cont # print("Orientation 6DRepNet: ", orientation_sixdrepnet) self.publish_orientation_icub(orientation_sixdrepnet) + def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--wrapyfi_cfg", - help="File to load Wrapyfi configs for running instance. " - "Choose one of the configs available in " - "./wrapyfi_configs/gaze_mirroring_multisensor. " - "for each running instance. All configs with a prefix of COMP must " - "have corresponding instances running. OPT prefixed config is optional " - "(only when using a vision model) executes the script on a machine connected " - "to the camera (or has access to the camera topic).", - type=str) - parser.add_argument("--control_sources", - help="Control sources to use for the experiment. The order of sources indicates the priority. " - "For example, if vision is the first source, then the vision source will " - "be used for control. If vision is not available, then the IMU source will be used. " - "If one source is provided, then the experiment will run with that source. ", - type=str, default=["vision", "imu"], nargs="+", choices=["vision", "imu"]) - parser.add_argument("--enable_gaze", - help="Enable the gaze (eye movement) control. If not enabled, " - "then the gaze control will not be used. ", - action="store_true", default=False) + parser.add_argument( + "--wrapyfi_cfg", + help="File to load Wrapyfi configs for running instance. " + "Choose one of the configs available in " + "./wrapyfi_configs/gaze_mirroring_multisensor. " + "for each running instance. All configs with a prefix of COMP must " + "have corresponding instances running. OPT prefixed config is optional " + "(only when using a vision model) executes the script on a machine connected " + "to the camera (or has access to the camera topic).", + type=str, + ) + parser.add_argument( + "--control_sources", + help="Control sources to use for the experiment. The order of sources indicates the priority. " + "For example, if vision is the first source, then the vision source will " + "be used for control. If vision is not available, then the IMU source will be used. " + "If one source is provided, then the experiment will run with that source. ", + type=str, + default=["vision", "imu"], + nargs="+", + choices=["vision", "imu"], + ) + parser.add_argument( + "--enable_gaze", + help="Enable the gaze (eye movement) control. If not enabled, " + "then the gaze control will not be used. ", + action="store_true", + default=False, + ) return parser.parse_args() @@ -107,8 +175,8 @@ def parse_args(): image_cam = None while True: if "vision" in args.control_sources: - image_webcam, = ec.listen_image_webcam() - image_webcam_linked, = ec.publish_sixdrepnet_cam(image_cam) + (image_webcam,) = ec.listen_image_webcam() + (image_webcam_linked,) = ec.publish_sixdrepnet_cam(image_cam) if image_webcam is None: image_webcam = image_webcam_linked image_cam = image_webcam @@ -116,15 +184,15 @@ def parse_args(): cv2.imshow("Webcam image", image_webcam) cv2.waitKey(1) - orientation_sixdrepnet, orientation_imu = ec.listen_orientation_sixdrepnet_waveshareimu() - ec.priority_control_sources(orientation_sixdrepnet, orientation_imu, args.control_sources) + orientation_sixdrepnet, orientation_imu = ( + ec.listen_orientation_sixdrepnet_waveshareimu() + ) + ec.priority_control_sources( + orientation_sixdrepnet, orientation_imu, args.control_sources + ) if args.enable_gaze: - gaze_pupil, = ec.listen_gaze_pupil() + (gaze_pupil,) = ec.listen_gaze_pupil() if gaze_pupil is not None: # print("Gaze Pupil: ", gaze_pupil) ec.publish_gaze_icub(gaze_pupil) - - - - diff --git a/examples/communication_patterns/request_reply_example.py b/examples/communication_patterns/request_reply_example.py index 1984b27..be69234 100755 --- a/examples/communication_patterns/request_reply_example.py +++ b/examples/communication_patterns/request_reply_example.py @@ -56,16 +56,28 @@ class ReqRep(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "ReqRep", "/req_rep/my_message", - carrier="tcp", persistent=True + "NativeObject", + "$mware", + "ReqRep", + "/req_rep/my_message", + carrier="tcp", + persistent=True, ) @MiddlewareCommunicator.register( - "Image", "$mware", "ReqRep", "/req_rep/my_image_message", - carrier="", width="$img_width", height="$img_height", rgb=True, jpg=True, - persistent=True + "Image", + "$mware", + "ReqRep", + "/req_rep/my_image_message", + carrier="", + width="$img_width", + height="$img_height", + rgb=True, + jpg=True, + persistent=True, ) - def send_img_message(self, msg=None, img_width=320, img_height=240, - mware=None, *args, **kwargs): + def send_img_message( + self, msg=None, img_width=320, img_height=240, mware=None, *args, **kwargs + ): """ Exchange messages with OpenCV images and other native Python objects. """ @@ -74,26 +86,62 @@ def send_img_message(self, msg=None, img_width=320, img_height=240, # read image from file img = cv2.imread("../../assets/wrapyfi.png") img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_AREA) - cv2.putText(img, msg, - ((img.shape[1] - cv2.getTextSize(msg, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0][0]) // 2, - (img.shape[0] + cv2.getTextSize(msg, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0][1]) // 2), - cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA) + cv2.putText( + img, + msg, + ( + ( + img.shape[1] + - cv2.getTextSize(msg, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0][0] + ) + // 2, + ( + img.shape[0] + + cv2.getTextSize(msg, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0][1] + ) + // 2, + ), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (255, 0, 0), + 2, + cv2.LINE_AA, + ) return obj, img @MiddlewareCommunicator.register( - "NativeObject", "$mware", "ReqRep", "/req_rep/my_message", - carrier="tcp", persistent=True + "NativeObject", + "$mware", + "ReqRep", + "/req_rep/my_message", + carrier="tcp", + persistent=True, ) @MiddlewareCommunicator.register( - "AudioChunk", "$mware", "ReqRep", "/req_rep/my_audio_message", - carrier="", rate="$aud_rate", chunk="$aud_chunk", channels="$aud_channels", - persistent=True + "AudioChunk", + "$mware", + "ReqRep", + "/req_rep/my_audio_message", + carrier="", + rate="$aud_rate", + chunk="$aud_chunk", + channels="$aud_channels", + persistent=True, ) - def send_aud_message(self, msg=None, - aud_rate=-1, aud_chunk=-1, aud_channels=2, - mware=None, *args, **kwargs): - """Exchange messages with sounddevice audio chunks and other native Python objects.""" + def send_aud_message( + self, + msg=None, + aud_rate=-1, + aud_chunk=-1, + aud_channels=2, + mware=None, + *args, + **kwargs, + ): + """ + Exchange messages with sounddevice audio chunks and other native Python objects. + """ obj = {"message": msg, "args": args, "kwargs": kwargs} # read audio from file aud = sf.read("../../assets/sound_test.wav", dtype="float32") @@ -102,32 +150,46 @@ def send_aud_message(self, msg=None, def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message requester and replier for native Python objects, images using OpenCV, and sound using PortAudio.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message requester and replier for native Python objects, images using OpenCV, and sound using PortAudio." + ) parser.add_argument( - "--mode", type=str, default="request", + "--mode", + type=str, + default="request", choices={"request", "reply"}, - help="The mode of communication, either 'request' or 'reply'" + help="The mode of communication, either 'request' or 'reply'", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) parser.add_argument( - "--sound_device", type=int, default=0, - help="The sound device to use for audio playback" + "--sound_device", + type=int, + default=0, + help="The sound device to use for audio playback", ) parser.add_argument( - "--list_sound_devices", action="store_true", - help="List all available sound devices and exit" + "--list_sound_devices", + action="store_true", + help="List all available sound devices and exit", ) parser.add_argument( - "--stream", type=str, default="image", choices={"image", "audio"}, - help="The streamed data as either 'image' or 'audio'" + "--stream", + type=str, + default="image", + choices={"image", "audio"}, + help="The streamed data as either 'image' or 'audio'", ) return parser.parse_args() @@ -145,7 +207,9 @@ def sound_play(my_aud, blocking=True, device=0): sd.play(*my_aud, blocking=blocking, device=device) return True except sd.PortAudioError: - logging.warning("PortAudioError: No device is found or the device is already in use. Will try again in 3 seconds.") + logging.warning( + "PortAudioError: No device is found or the device is already in use. Will try again in 3 seconds." + ) return False @@ -154,10 +218,17 @@ def main(args): print(sd.query_devices()) return - """Main function to initiate ReqRep class and communication.""" + """ + + Main function to initiate ReqRep class and communication. + """ if args.mode == "request" and args.mware == "zeromq": - print("WE INTENTIONALLY WAIT 5 SECONDS TO ALLOW THE REPLIER ENOUGH TIME TO START UP. ") - print("THIS IS NEEDED WHEN USING ZEROMQ AS THE COMMUNICATION MIDDLEWARE IF THE SERVER ") + print( + "WE INTENTIONALLY WAIT 5 SECONDS TO ALLOW THE REPLIER ENOUGH TIME TO START UP. " + ) + print( + "THIS IS NEEDED WHEN USING ZEROMQ AS THE COMMUNICATION MIDDLEWARE IF THE SERVER " + ) print("IS SET TO SPAWN A PROXY BROKER (DEFAULT).") time.sleep(5) @@ -175,8 +246,13 @@ def main(args): # but this separation is NOT necessary for the method to work if args.mode == "request": msg = input("Type your message: ") - my_message, my_image = req_rep.send_img_message(msg, counter=counter, mware=args.mware) - my_message2, my_aud, = req_rep.send_aud_message(msg, counter=counter, mware=args.mware) + my_message, my_image = req_rep.send_img_message( + msg, counter=counter, mware=args.mware + ) + ( + my_message2, + my_aud, + ) = req_rep.send_aud_message(msg, counter=counter, mware=args.mware) my_message = my_message2 if my_message2 is not None else my_message counter += 1 if my_message is not None: @@ -186,7 +262,9 @@ def main(args): cv2.imshow("Received image", my_image) while True: k = cv2.waitKey(1) & 0xFF - if not (cv2.getWindowProperty("Received image", cv2.WND_PROP_VISIBLE)): + if not ( + cv2.getWindowProperty("Received image", cv2.WND_PROP_VISIBLE) + ): break if cv2.waitKey(1) == 27: @@ -204,7 +282,10 @@ def main(args): # The send_message() only executes in "reply" mode, # meaning, the method is only accessible from this code block my_message, my_image = req_rep.send_img_message(mware=args.mware) - my_message2, my_aud, = req_rep.send_aud_message(mware=args.mware) + ( + my_message2, + my_aud, + ) = req_rep.send_aud_message(mware=args.mware) my_message = my_message2 if my_message2 is not None else my_message if my_message is not None: print("Reply: received reply:", my_message) @@ -229,4 +310,4 @@ def main(args): if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/examples/communication_schemes/channeling_example.py b/examples/communication_schemes/channeling_example.py index 8295161..ed8bf9e 100644 --- a/examples/communication_schemes/channeling_example.py +++ b/examples/communication_schemes/channeling_example.py @@ -42,55 +42,100 @@ class ChannelingCls(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware_A", "ChannelingCls", "/example/native_A_msg", - carrier="mcast", should_wait=True + "NativeObject", + "$mware_A", + "ChannelingCls", + "/example/native_A_msg", + carrier="mcast", + should_wait=True, ) @MiddlewareCommunicator.register( - "Image", "$mware_B", "ChannelingCls", "/example/image_B_msg", - carrier="tcp", width="$img_width", height="$img_height", rgb=True, should_wait=False + "Image", + "$mware_B", + "ChannelingCls", + "/example/image_B_msg", + carrier="tcp", + width="$img_width", + height="$img_height", + rgb=True, + should_wait=False, ) @MiddlewareCommunicator.register( - "AudioChunk", "$mware_C", "ChannelingCls", "/example/audio_C_msg", - carrier="tcp", rate="$aud_rate", chunk="$aud_chunk", channels="$aud_channels", should_wait=False + "AudioChunk", + "$mware_C", + "ChannelingCls", + "/example/audio_C_msg", + carrier="tcp", + rate="$aud_rate", + chunk="$aud_chunk", + channels="$aud_channels", + should_wait=False, ) - def read_mulret_mulmware(self, img_width=200, img_height=200, - aud_rate=44100, aud_chunk=8820, aud_channels=1, - mware_A=None, mware_B=None, mware_C=None): - """Read and forward messages through channels A, B, and C.""" - ros_img = np.random.randint(256, size=(img_height, img_width, 3), dtype=np.uint8) - zeromq_aud = (np.random.uniform(-1, 1, aud_chunk), aud_rate,) + def read_mulret_mulmware( + self, + img_width=200, + img_height=200, + aud_rate=44100, + aud_chunk=8820, + aud_channels=1, + mware_A=None, + mware_B=None, + mware_C=None, + ): + """ + Read and forward messages through channels A, B, and C. + """ + ros_img = np.random.randint( + 256, size=(img_height, img_width, 3), dtype=np.uint8 + ) + zeromq_aud = ( + np.random.uniform(-1, 1, aud_chunk), + aud_rate, + ) yarp_native = [ros_img, zeromq_aud] return yarp_native, ros_img, zeromq_aud def parse_args(): - """Parse command line arguments.""" + """ + Parse command line arguments. + """ parser = argparse.ArgumentParser(description="Channeling Example using Wrapyfi.") parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware_A", type=str, default="none", + "--mware_A", + type=str, + default="none", choices=MiddlewareCommunicator.get_communicators().update({"none"}), - help="The middleware to use for transmission of channel A" + help="The middleware to use for transmission of channel A", ) parser.add_argument( - "--mware_B", type=str, default="none", + "--mware_B", + type=str, + default="none", choices=MiddlewareCommunicator.get_communicators().update({"none"}), - help="The middleware to use for transmission of channel B" + help="The middleware to use for transmission of channel B", ) parser.add_argument( - "--mware_C", type=str, default="none", + "--mware_C", + type=str, + default="none", choices=MiddlewareCommunicator.get_communicators().update({"none"}), - help="The middleware to use for transmission of channel C" + help="The middleware to use for transmission of channel C", ) return parser.parse_args() def main(args): - """Main function to initiate ChannelingCls class and communication.""" + """ + Main function to initiate ChannelingCls class and communication. + """ channeling = ChannelingCls() channeling.activate_communication(channeling.read_mulret_mulmware, mode=args.mode) diff --git a/examples/communication_schemes/forwarding_example.py b/examples/communication_schemes/forwarding_example.py index 0d8cebe..f49dd1e 100644 --- a/examples/communication_schemes/forwarding_example.py +++ b/examples/communication_schemes/forwarding_example.py @@ -36,62 +36,95 @@ class ForwardCls(MiddlewareCommunicator): @MiddlewareCommunicator.register( - 'NativeObject', '$mware_chain_A', 'ForwardCls', '/example/native_chain_A_msg', - carrier='mcast', should_wait=True) - def read_chain_A(self, mware_chain_A=None, msg=''): - """Read and forward message from chain A.""" - return msg, + "NativeObject", + "$mware_chain_A", + "ForwardCls", + "/example/native_chain_A_msg", + carrier="mcast", + should_wait=True, + ) + def read_chain_A(self, mware_chain_A=None, msg=""): + """ + Read and forward message from chain A. + """ + return (msg,) @MiddlewareCommunicator.register( - 'NativeObject', '$mware_chain_B', 'ForwardCls', '/example/native_chain_B_msg', - carrier='tcp', should_wait=False) - def read_chain_B(self, mware_chain_B=None, msg=''): - """Read and forward message from chain B.""" - return msg, + "NativeObject", + "$mware_chain_B", + "ForwardCls", + "/example/native_chain_B_msg", + carrier="tcp", + should_wait=False, + ) + def read_chain_B(self, mware_chain_B=None, msg=""): + """ + Read and forward message from chain B. + """ + return (msg,) + def parse_args(): - """Parse command line arguments.""" + """ + Parse command line arguments. + """ parser = argparse.ArgumentParser(description="Forwarding Example using Wrapyfi.") parser.add_argument( - "--mode_chain_A", type=str, default="publish", + "--mode_chain_A", + type=str, + default="publish", choices=["listen", "publish", "disable", "none", None], - help="The mode of transmission for the first method in the chain" + help="The mode of transmission for the first method in the chain", ) parser.add_argument( - "--mode_chain_B", type=str, default="listen", + "--mode_chain_B", + type=str, + default="listen", choices=["listen", "publish", "disable", "none", None], - help="The mode of transmission for the second method in the chain" + help="The mode of transmission for the second method in the chain", ) parser.add_argument( - "--mware_chain_A", type=str, default=DEFAULT_COMMUNICATOR, + "--mware_chain_A", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission of the first method in the chain" + help="The middleware to use for transmission of the first method in the chain", ) parser.add_argument( - "--mware_chain_B", type=str, default=DEFAULT_COMMUNICATOR, + "--mware_chain_B", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission of the second method in the chain" + help="The middleware to use for transmission of the second method in the chain", ) return parser.parse_args() def main(args): - """Main function to initiate ForwardCls class and communication.""" + """ + Main function to initiate ForwardCls class and communication. + """ forward = ForwardCls() forward.activate_communication(forward.read_chain_A, mode=args.mode_chain_A) forward.activate_communication(forward.read_chain_B, mode=args.mode_chain_B) while True: - msg, = forward.read_chain_A(mware_chain_A=args.mware_chain_A, - msg=f"This argument message was sent from read_chain_A transmitted over " - f"{args.mware_chain_A}") + (msg,) = forward.read_chain_A( + mware_chain_A=args.mware_chain_A, + msg=f"This argument message was sent from read_chain_A transmitted over " + f"{args.mware_chain_A}", + ) if msg is not None: print(msg) - msg, = forward.read_chain_B(mware_chain_B=args.mware_chain_B, - msg=f"{msg}. It was then forwarded to read_chain_B over {args.mware_chain_B}") + (msg,) = forward.read_chain_B( + mware_chain_B=args.mware_chain_B, + msg=f"{msg}. It was then forwarded to read_chain_B over {args.mware_chain_B}", + ) if msg is not None: if args.mode_chain_B == "listen": - print(f"{msg}. This message is the last in the chain received over {args.mware_chain_B}") + print( + f"{msg}. This message is the last in the chain received over {args.mware_chain_B}" + ) else: print(msg) time.sleep(0.1) diff --git a/examples/communication_schemes/mirroring_example.py b/examples/communication_schemes/mirroring_example.py index 8ba09c0..06bd089 100644 --- a/examples/communication_schemes/mirroring_example.py +++ b/examples/communication_schemes/mirroring_example.py @@ -41,37 +41,56 @@ class MirrorCls(MiddlewareCommunicator): @MiddlewareCommunicator.register( - 'NativeObject', '$mware', 'MirrorCls', - '/example/read_msg', - carrier='tcp', should_wait='$blocking') - def read_msg(self, mware=None, msg='', blocking=True): - """Exchange messages and mirror user input.""" - msg_ip = input('Type message: ') - obj = {'msg': msg, 'msg_ip': msg_ip} - return obj, + "NativeObject", + "$mware", + "MirrorCls", + "/example/read_msg", + carrier="tcp", + should_wait="$blocking", + ) + def read_msg(self, mware=None, msg="", blocking=True): + """ + Exchange messages and mirror user input. + """ + msg_ip = input("Type message: ") + obj = {"msg": msg, "msg_ip": msg_ip} + return (obj,) + def parse_args(): - """Parse command line arguments.""" + """ + Parse command line arguments. + """ parser = argparse.ArgumentParser(description="Mirroring Example using Wrapyfi.") parser.add_argument( - "--mode", type=str, default="listen", + "--mode", + type=str, + default="listen", choices={"publish", "listen", "request", "reply"}, - help="The communication mode (publish, listen, request, reply)" + help="The communication mode (publish, listen, request, reply)", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() + def main(args): - """Main function to initiate MirrorCls class and communication.""" + """ + Main function to initiate MirrorCls class and communication. + """ mirror = MirrorCls() mirror.activate_communication(MirrorCls.read_msg, mode=args.mode) while True: - msg_object, = mirror.read_msg(mware=args.mware, msg=f"This argument message was sent by the {args.mode} script") + (msg_object,) = mirror.read_msg( + mware=args.mware, + msg=f"This argument message was sent by the {args.mode} script", + ) if msg_object is not None: print(msg_object) diff --git a/examples/custom_msgs/ros2_message_example.py b/examples/custom_msgs/ros2_message_example.py index 0201720..49a3239 100644 --- a/examples/custom_msgs/ros2_message_example.py +++ b/examples/custom_msgs/ros2_message_example.py @@ -35,15 +35,23 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "ROS2Message", "ros2", "Notifier", "/notify/test_ros2_msg_str_exchange", - should_wait=True + "ROS2Message", + "ros2", + "Notifier", + "/notify/test_ros2_msg_str_exchange", + should_wait=True, ) @MiddlewareCommunicator.register( - "ROS2Message", "ros2", "Notifier", "/notify/test_ros2_msg_pose_exchange", - should_wait=True + "ROS2Message", + "ros2", + "Notifier", + "/notify/test_ros2_msg_pose_exchange", + should_wait=True, ) def send_message(self): - """Exchange ROS 2 messages over ROS 2.""" + """ + Exchange ROS 2 messages over ROS 2. + """ msg = input("Type your message: ") quat = Quaternion() quat.x = 0.1 @@ -57,18 +65,26 @@ def send_message(self): def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message publisher and listener for ROS 2 messages using Wrapyfi.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message publisher and listener for ROS 2 messages using Wrapyfi." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) return parser.parse_args() def main(args): - """Main function to initiate Notify class and communication.""" + """ + Main function to initiate Notify class and communication. + """ ros2_message = Notifier() ros2_message.activate_communication(Notifier.send_message, mode=args.mode) diff --git a/examples/custom_msgs/ros_message_example.py b/examples/custom_msgs/ros_message_example.py index 8c2b255..1cf1694 100755 --- a/examples/custom_msgs/ros_message_example.py +++ b/examples/custom_msgs/ros_message_example.py @@ -35,15 +35,23 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "ROSMessage", "ros", "Notifier", "/notify/test_ros_msg_str_exchange", - should_wait=True + "ROSMessage", + "ros", + "Notifier", + "/notify/test_ros_msg_str_exchange", + should_wait=True, ) @MiddlewareCommunicator.register( - "ROSMessage", "ros", "Notifier", "/notify/test_ros_msg_pose_exchange", - should_wait=True + "ROSMessage", + "ros", + "Notifier", + "/notify/test_ros_msg_pose_exchange", + should_wait=True, ) def send_message(self): - """Exchange ROS messages over ROS.""" + """ + Exchange ROS messages over ROS. + """ msg = input("Type your message: ") quat = Quaternion() quat.x = 0.1 @@ -57,18 +65,26 @@ def send_message(self): def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message publisher and listener for ROS messages using Wrapyfi.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message publisher and listener for ROS messages using Wrapyfi." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) return parser.parse_args() def main(args): - """Main function to initiate Notify class and communication.""" + """ + Main function to initiate Notify class and communication. + """ ros_message = Notifier() ros_message.activate_communication(Notifier.send_message, mode=args.mode) diff --git a/examples/custom_msgs/ros_parameter_example.py b/examples/custom_msgs/ros_parameter_example.py index 28c45c2..983e28d 100755 --- a/examples/custom_msgs/ros_parameter_example.py +++ b/examples/custom_msgs/ros_parameter_example.py @@ -32,44 +32,73 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "Properties", "ros", "Notifier", "/notify/test_property_exchange", - should_wait=True + "Properties", + "ros", + "Notifier", + "/notify/test_property_exchange", + should_wait=True, ) @MiddlewareCommunicator.register( - "Properties", "ros", "Notifier", "/notify/test_property_exchange/a", - should_wait=True + "Properties", + "ros", + "Notifier", + "/notify/test_property_exchange/a", + should_wait=True, ) @MiddlewareCommunicator.register( - "Properties", "ros", "Notifier", "/notify/test_property_exchange/e", - persistent=False, should_wait=True + "Properties", + "ros", + "Notifier", + "/notify/test_property_exchange/e", + persistent=False, + should_wait=True, ) def set_property(self): - """Exchange ROS properties over ROS.""" + """ + Exchange ROS properties over ROS. + """ ret_str = input("Type your message: ") - ret_multiprops = {"b": [1,2,3,4], "c": False, "d": 12.3} - ret_non_persistent = "Non-persistent property which should be deleted on closure" + ret_multiprops = {"b": [1, 2, 3, 4], "c": False, "d": 12.3} + ret_non_persistent = ( + "Non-persistent property which should be deleted on closure" + ) return ret_multiprops, ret_str, ret_non_persistent def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message publisher and listener for ROS properties using Wrapyfi.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message publisher and listener for ROS properties using Wrapyfi." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) return parser.parse_args() def main(args): - """Main function to initiate Notify class and communication.""" + """ + Main function to initiate Notify class and communication. + """ ros_message = Notifier() ros_message.activate_communication(Notifier.set_property, mode=args.mode) while True: - my_dict_message, my_string_message, my_nonpersistent_message = ros_message.set_property() - print("Method result:", my_string_message, my_dict_message, my_nonpersistent_message) + my_dict_message, my_string_message, my_nonpersistent_message = ( + ros_message.set_property() + ) + print( + "Method result:", + my_string_message, + my_dict_message, + my_nonpersistent_message, + ) if __name__ == "__main__": diff --git a/examples/encoders/astropy_example.py b/examples/encoders/astropy_example.py index a010149..3dd0aff 100644 --- a/examples/encoders/astropy_example.py +++ b/examples/encoders/astropy_example.py @@ -41,57 +41,72 @@ # Modifying the WRAPYFI_PLUGINS_PATH environment variable to include the plugins directory script_dir = os.path.dirname(os.path.realpath(__file__)) -if 'WRAPYFI_PLUGINS_PATH' in os.environ: - os.environ['WRAPYFI_PLUGINS_PATH'] += os.pathsep + script_dir +if "WRAPYFI_PLUGINS_PATH" in os.environ: + os.environ["WRAPYFI_PLUGINS_PATH"] += os.pathsep + script_dir else: - os.environ['WRAPYFI_PLUGINS_PATH'] = script_dir + os.environ["WRAPYFI_PLUGINS_PATH"] = script_dir class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_astropy_exchange", - carrier="tcp", should_wait=True + "NativeObject", + "$mware", + "Notifier", + "/notify/test_astropy_exchange", + carrier="tcp", + should_wait=True, ) def exchange_object(self, mware=None): - """Exchange messages with Astropy Tables and other native Python objects.""" + """ + Exchange messages with Astropy Tables and other native Python objects. + """ msg = input("Type your message: ") # Creating an example Astropy Table t = Table() - t['name'] = ['source 1', 'source 2', 'source 3'] - t['flux'] = [1.2, 2.2, 3.1] + t["name"] = ["source 1", "source 2", "source 3"] + t["flux"] = [1.2, 2.2, 3.1] ret = { "message": msg, "astropy_table": t, } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" + """ + Parse command line arguments. + """ parser = argparse.ArgumentParser( - description="A message publisher and listener for native Python objects and Astropy Tables.") + description="A message publisher and listener for native Python objects and Astropy Tables." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) diff --git a/examples/encoders/cupy_example.py b/examples/encoders/cupy_example.py index 28af7fb..854b30c 100644 --- a/examples/encoders/cupy_example.py +++ b/examples/encoders/cupy_example.py @@ -39,44 +39,60 @@ class CuPyNotifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "CuPyNotifier", "/notify/test_cupy_exchange", - carrier="", should_wait=True, - listener_kwargs=dict(load_cupy_device=cp.cuda.Device(0)) + "NativeObject", + "$mware", + "CuPyNotifier", + "/notify/test_cupy_exchange", + carrier="", + should_wait=True, + listener_kwargs=dict(load_cupy_device=cp.cuda.Device(0)), ) def exchange_object(self, mware=None): - """Exchange messages with CuPy tensors and other native Python objects.""" + """ + Exchange messages with CuPy tensors and other native Python objects. + """ msg = input("Type your message: ") ret = { "message": msg, "cupy_ones_cuda": cp.ones((2, 4), dtype=cp.float32), - "cupy_zeros_cuda": cp.zeros((2, 3), dtype=cp.float32) + "cupy_zeros_cuda": cp.zeros((2, 3), dtype=cp.float32), } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and CuPy tensors.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message publisher and listener for native Python objects and CuPy tensors." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate CuPyNotifier class and communication.""" + """ + Main function to initiate CuPyNotifier class and communication. + """ notifier = CuPyNotifier() notifier.activate_communication(CuPyNotifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) diff --git a/examples/encoders/dask_example.py b/examples/encoders/dask_example.py index f914156..dbcc706 100644 --- a/examples/encoders/dask_example.py +++ b/examples/encoders/dask_example.py @@ -40,19 +40,28 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_dask_exchange", - carrier="tcp", should_wait=True + "NativeObject", + "$mware", + "Notifier", + "/notify/test_dask_exchange", + carrier="tcp", + should_wait=True, ) def exchange_object(self, mware=None): - """Exchange messages with Dask arrays/dataframes and other native Python objects.""" + """ + Exchange messages with Dask arrays/dataframes and other native Python objects. + """ msg = input("Type your message: ") # Creating an example Dask DataFrame - df = pd.DataFrame({ - 'num_legs': [4, 2, 0, 4], - 'num_wings': [0, 2, 0, 0], - 'num_specimen_seen': [10, 2, 1, 8] - }, index=['falcon', 'parrot', 'fish', 'dog']) + df = pd.DataFrame( + { + "num_legs": [4, 2, 0, 4], + "num_wings": [0, 2, 0, 0], + "num_specimen_seen": [10, 2, 1, 8], + }, + index=["falcon", "parrot", "fish", "dog"], + ) ddf = dd.from_pandas(df, npartitions=2) @@ -67,32 +76,42 @@ def exchange_object(self, mware=None): "dask_array": darray, "dask_series": dds, } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and Dask arrays/dataframes.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message publisher and listener for native Python objects and Dask arrays/dataframes." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) # Compute and print the actual values of the Dask objects for key, value in msg_object.items(): diff --git a/examples/encoders/jax_example.py b/examples/encoders/jax_example.py index 8778e56..b5022f8 100755 --- a/examples/encoders/jax_example.py +++ b/examples/encoders/jax_example.py @@ -38,42 +38,58 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_jax_exchange", - carrier="tcp", should_wait=True + "NativeObject", + "$mware", + "Notifier", + "/notify/test_jax_exchange", + carrier="tcp", + should_wait=True, ) def exchange_object(self, mware=None): - """Exchange messages with JAX arrays and other native Python objects.""" + """ + Exchange messages with JAX arrays and other native Python objects. + """ msg = input("Type your message: ") ret = { "message": msg, "jax_ones": jnp.ones((2, 4)), } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and JAX arrays.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message publisher and listener for native Python objects and JAX arrays." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) diff --git a/examples/encoders/mxnet_example.py b/examples/encoders/mxnet_example.py index 1ff24a2..d34c9dc 100755 --- a/examples/encoders/mxnet_example.py +++ b/examples/encoders/mxnet_example.py @@ -40,44 +40,63 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_mxnet_exchange", - carrier="", should_wait=True, - listener_kwargs=dict(load_mxnet_device=mxnet.gpu(0), map_mxnet_devices={'cpu': 'cuda:0', 'gpu:0': 'cpu'}) + "NativeObject", + "$mware", + "Notifier", + "/notify/test_mxnet_exchange", + carrier="", + should_wait=True, + listener_kwargs=dict( + load_mxnet_device=mxnet.gpu(0), + map_mxnet_devices={"cpu": "cuda:0", "gpu:0": "cpu"}, + ), ) def exchange_object(self, mware=None): - """Exchange messages with MXNet tensors and other native Python objects.""" + """ + Exchange messages with MXNet tensors and other native Python objects. + """ msg = input("Type your message: ") ret = { "message": msg, "mx_ones": mxnet.nd.ones((2, 4), ctx=mxnet.cpu()), - "mxnet_zeros_cuda": mxnet.nd.zeros((2, 3), ctx=mxnet.gpu(0)) + "mxnet_zeros_cuda": mxnet.nd.zeros((2, 3), ctx=mxnet.gpu(0)), } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and MXNet tensors.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message publisher and listener for native Python objects and MXNet tensors." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) diff --git a/examples/encoders/numpy_pandas_example.py b/examples/encoders/numpy_pandas_example.py index f069280..737fa20 100755 --- a/examples/encoders/numpy_pandas_example.py +++ b/examples/encoders/numpy_pandas_example.py @@ -40,46 +40,65 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_native_exchange", - carrier="tcp", should_wait=True + "NativeObject", + "$mware", + "Notifier", + "/notify/test_native_exchange", + carrier="tcp", + should_wait=True, ) def exchange_object(self, mware=None): - """Exchange messages with NumPy arrays, pandas series/dataframes, and other native Python objects.""" + """ + Exchange messages with NumPy arrays, pandas series/dataframes, and other native Python objects. + """ msg = input("Type your message: ") ret = { "message": msg, "numpy_array": np.ones((2, 4)), "pandas_series": pd.Series([1, 3, 5, np.nan, 6, 8]), - "pandas_dataframe": pd.DataFrame(np.random.randn(6, 4), index=pd.date_range("20130101", periods=6), - columns=list("ABCD")), + "pandas_dataframe": pd.DataFrame( + np.random.randn(6, 4), + index=pd.date_range("20130101", periods=6), + columns=list("ABCD"), + ), } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects, NumPy arrays, and pandas series/dataframes.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message publisher and listener for native Python objects, NumPy arrays, and pandas series/dataframes." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) diff --git a/examples/encoders/paddlepaddle_example.py b/examples/encoders/paddlepaddle_example.py index bce99ef..b5e71e3 100755 --- a/examples/encoders/paddlepaddle_example.py +++ b/examples/encoders/paddlepaddle_example.py @@ -29,6 +29,7 @@ """ import argparse + try: import paddle except ImportError: @@ -39,44 +40,67 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_paddle_exchange", - carrier="", should_wait=True, - listener_kwargs=dict(load_paddle_device='gpu:0', map_paddle_devices={'cpu': 'cuda:0', 'gpu:0': 'cpu'}) + "NativeObject", + "$mware", + "Notifier", + "/notify/test_paddle_exchange", + carrier="", + should_wait=True, + listener_kwargs=dict( + load_paddle_device="gpu:0", + map_paddle_devices={"cpu": "cuda:0", "gpu:0": "cpu"}, + ), ) def exchange_object(self, mware=None): - """Exchange messages with PaddlePaddle tensors and other native Python objects.""" + """ + Exchange messages with PaddlePaddle tensors and other native Python objects. + """ msg = input("Type your message: ") ret = { "message": msg, - "paddle_ones": paddle.ones([2, 4], dtype='float32', place=paddle.CPUPlace()), - "paddle_zeros_cuda": paddle.zeros([2, 3], dtype='float32', place=paddle.CUDAPlace(0)) + "paddle_ones": paddle.ones( + [2, 4], dtype="float32", place=paddle.CPUPlace() + ), + "paddle_zeros_cuda": paddle.zeros( + [2, 3], dtype="float32", place=paddle.CUDAPlace(0) + ), } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and PaddlePaddle tensors.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message publisher and listener for native Python objects and PaddlePaddle tensors." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) diff --git a/examples/encoders/pillow_example.py b/examples/encoders/pillow_example.py index 7fa29f5..74edd05 100755 --- a/examples/encoders/pillow_example.py +++ b/examples/encoders/pillow_example.py @@ -40,44 +40,56 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notify", "/notify/test_native_exchange", - carrier="", should_wait=True + "NativeObject", + "$mware", + "Notify", + "/notify/test_native_exchange", + carrier="", + should_wait=True, ) def exchange_object(self, mware=None): msg = input("Type your message: ") imarray = np.random.rand(100, 100, 3) * 255 ret = { "message": msg, - "pillow_random": Image.fromarray(imarray.astype('uint8')).convert('RGBA'), + "pillow_random": Image.fromarray(imarray.astype("uint8")).convert("RGBA"), "pillow_png": Image.open("../../assets/wrapyfi.png"), - "pillow_jpg": Image.open("../../assets/wrapyfi.jpg") + "pillow_jpg": Image.open("../../assets/wrapyfi.jpg"), } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" + """ + Parse command line arguments. + """ parser = argparse.ArgumentParser() parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notify class and communication.""" + """ + Main function to initiate Notify class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) diff --git a/examples/encoders/pint_example.py b/examples/encoders/pint_example.py index 040b068..81a2d45 100644 --- a/examples/encoders/pint_example.py +++ b/examples/encoders/pint_example.py @@ -38,53 +38,76 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_pint_exchange", - carrier="tcp", should_wait=True + "NativeObject", + "$mware", + "Notifier", + "/notify/test_pint_exchange", + carrier="tcp", + should_wait=True, ) def exchange_object(self, mware=None): - """Exchange messages with Pint Quantities and other native Python objects.""" + """ + Exchange messages with Pint Quantities and other native Python objects. + """ msg = input("Type your message: ") # Creating a Pint Quantity ureg = pint.UnitRegistry() - quantity = 42 * ureg.parse_expression('meter') + quantity = 42 * ureg.parse_expression("meter") # Constructing the message object to be transmitted - ret = [{"message": msg, - "pint_quantity": quantity, - "list": [1, 2, 3]}, - "string", - 0.4344, - {"other": (1, 2, 3, 4.32,)}] - return ret, + ret = [ + {"message": msg, "pint_quantity": quantity, "list": [1, 2, 3]}, + "string", + 0.4344, + { + "other": ( + 1, + 2, + 3, + 4.32, + ) + }, + ] + return (ret,) def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and Pint Quantities.") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A message publisher and listener for native Python objects and Pint Quantities." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/examples/encoders/plugins/astropy_tables.py b/examples/encoders/plugins/astropy_tables.py index f3fe48f..375c1b8 100644 --- a/examples/encoders/plugins/astropy_tables.py +++ b/examples/encoders/plugins/astropy_tables.py @@ -24,6 +24,7 @@ try: from astropy.table import Table + HAVE_ASTROPY = True except ImportError: HAVE_ASTROPY = False @@ -54,9 +55,9 @@ def encode(self, obj, *args, **kwargs): - '__wrapyfi__': A tuple containing the class name and the encoded data string """ memfile = io.BytesIO() - obj.write(memfile, format='fits') + obj.write(memfile, format="fits") memfile.seek(0) - obj_data = base64.b64encode(memfile.getvalue()).decode('ascii') + obj_data = base64.b64encode(memfile.getvalue()).decode("ascii") memfile.close() return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data)) @@ -74,8 +75,8 @@ def decode(self, obj_type, obj_full, *args, **kwargs): """ encoded_str = obj_full[1] if isinstance(encoded_str, str): - encoded_str = encoded_str.encode('ascii') + encoded_str = encoded_str.encode("ascii") with io.BytesIO(base64.b64decode(encoded_str)) as memfile: memfile.seek(0) - obj = Table.read(memfile, format='fits') + obj = Table.read(memfile, format="fits") return True, obj diff --git a/examples/encoders/pyarrow_example.py b/examples/encoders/pyarrow_example.py index ec87244..9027575 100755 --- a/examples/encoders/pyarrow_example.py +++ b/examples/encoders/pyarrow_example.py @@ -27,6 +27,7 @@ """ import argparse + try: import pyarrow as pa except ImportError: @@ -37,46 +38,61 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_arrow_exchange", - carrier="tcp", should_wait=True + "NativeObject", + "$mware", + "Notifier", + "/notify/test_arrow_exchange", + carrier="tcp", + should_wait=True, ) def exchange_object(self, mware=None): - """Exchange messages with PyArrow arrays and other native Python objects.""" + """ + Exchange messages with PyArrow arrays and other native Python objects. + """ msg = input("Type your message: ") ret = { "message": msg, "pyarrow_array": pa.array(range(100)), } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" + """ + Parse command line arguments. + """ parser = argparse.ArgumentParser( - description="A message publisher and listener for native Python objects and PyArrow arrays.") + description="A message publisher and listener for native Python objects and PyArrow arrays." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/examples/encoders/pytorch_example.py b/examples/encoders/pytorch_example.py index 37b4351..1da2e91 100755 --- a/examples/encoders/pytorch_example.py +++ b/examples/encoders/pytorch_example.py @@ -40,45 +40,63 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_torch_exchange", - carrier="", should_wait=True, - listener_kwargs=dict(load_torch_device='cuda:0', map_torch_devices={'cpu': 'cuda:0', 'cuda:0': 'cpu'}) + "NativeObject", + "$mware", + "Notifier", + "/notify/test_torch_exchange", + carrier="", + should_wait=True, + listener_kwargs=dict( + load_torch_device="cuda:0", + map_torch_devices={"cpu": "cuda:0", "cuda:0": "cpu"}, + ), ) def exchange_object(self, mware=None): - """Exchange messages with PyTorch tensors and other native Python objects.""" + """ + Exchange messages with PyTorch tensors and other native Python objects. + """ msg = input("Type your message: ") ret = { "message": msg, - "torch_ones": torch.ones((2, 4), device='cpu'), - "torch_zeros_cuda": torch.zeros((2, 3), device='cuda:0') + "torch_ones": torch.ones((2, 4), device="cpu"), + "torch_zeros_cuda": torch.zeros((2, 3), device="cuda:0"), } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" + """ + Parse command line arguments. + """ parser = argparse.ArgumentParser( - description="A message publisher and listener for native Python objects and PyTorch tensors.") + description="A message publisher and listener for native Python objects and PyTorch tensors." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) diff --git a/examples/encoders/tensorflow_example.py b/examples/encoders/tensorflow_example.py index fe3290f..c17bcd8 100755 --- a/examples/encoders/tensorflow_example.py +++ b/examples/encoders/tensorflow_example.py @@ -27,6 +27,7 @@ """ import argparse + try: import tensorflow as tf except ImportError: @@ -37,45 +38,60 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_native_exchange", - carrier="", should_wait=True + "NativeObject", + "$mware", + "Notifier", + "/notify/test_native_exchange", + carrier="", + should_wait=True, ) def exchange_object(self, mware=None): - """Exchange messages with TensorFlow tensors and other native Python objects.""" + """ + Exchange messages with TensorFlow tensors and other native Python objects. + """ msg = input("Type your message: ") ret = { "message": msg, "tf_ones": tf.ones((2, 4)), - "tf_string": tf.constant("This is string") + "tf_string": tf.constant("This is string"), } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" + """ + Parse command line arguments. + """ parser = argparse.ArgumentParser( - description="A message publisher and listener for native Python objects and TensorFlow tensors.") + description="A message publisher and listener for native Python objects and TensorFlow tensors." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) diff --git a/examples/encoders/xarray_example.py b/examples/encoders/xarray_example.py index 51f9b89..eeb3365 100644 --- a/examples/encoders/xarray_example.py +++ b/examples/encoders/xarray_example.py @@ -41,57 +41,74 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_xarray_exchange", - carrier="tcp", should_wait=True + "NativeObject", + "$mware", + "Notifier", + "/notify/test_xarray_exchange", + carrier="tcp", + should_wait=True, ) def exchange_object(self, mware=None): - """Exchange messages with xarray DataArrays and other native Python objects.""" + """ + Exchange messages with xarray DataArrays and other native Python objects. + """ msg = input("Type your message: ") # Creating an example xarray DataArray data = np.random.rand(4, 3) - locs = ['IA', 'IL', 'IN'] - times = pd.date_range('2000-01-01', periods=4) - da = xr.DataArray(data, coords=[times, locs], dims=['time', 'space'], name='example') + locs = ["IA", "IL", "IN"] + times = pd.date_range("2000-01-01", periods=4) + da = xr.DataArray( + data, coords=[times, locs], dims=["time", "space"], name="example" + ) ret = { "message": msg, "xarray_dataarray": da, "additional_info": { - "set": {'a', 1, None}, - "list": [[[3, [4], 5.677890, 1.2]]] - } + "set": {"a", 1, None}, + "list": [[[3, [4], 5.677890, 1.2]]], + }, } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" + """ + Parse command line arguments. + """ parser = argparse.ArgumentParser( - description="A message publisher and listener for native Python objects and xarray DataArrays.") + description="A message publisher and listener for native Python objects and xarray DataArrays." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/examples/encoders/zarr_example.py b/examples/encoders/zarr_example.py index 983cc66..4a60d2d 100644 --- a/examples/encoders/zarr_example.py +++ b/examples/encoders/zarr_example.py @@ -40,11 +40,17 @@ class Notifier(MiddlewareCommunicator): @MiddlewareCommunicator.register( - "NativeObject", "$mware", "Notifier", "/notify/test_zarr_exchange", - carrier="tcp", should_wait=True + "NativeObject", + "$mware", + "Notifier", + "/notify/test_zarr_exchange", + carrier="tcp", + should_wait=True, ) def exchange_object(self, mware=None): - """Exchange messages with Zarr arrays/groups and other native Python objects.""" + """ + Exchange messages with Zarr arrays/groups and other native Python objects. + """ msg = input("Type your message: ") # Creating an example Zarr Array @@ -52,41 +58,50 @@ def exchange_object(self, mware=None): # Creating an example Zarr Group zgroup = zarr.group() - zgroup.create_dataset('dataset1', data=np.random.randint(0, 100, 50), chunks=10) - zgroup.create_dataset('dataset2', data=np.random.random(100), chunks=10) + zgroup.create_dataset("dataset1", data=np.random.randint(0, 100, 50), chunks=10) + zgroup.create_dataset("dataset2", data=np.random.random(100), chunks=10) ret = { "message": msg, "zarr_array": zarray, "zarr_group": zgroup, } - return ret, + return (ret,) def parse_args(): - """Parse command line arguments.""" + """ + Parse command line arguments. + """ parser = argparse.ArgumentParser( - description="A message publisher and listener for native Python objects and Zarr arrays/groups.") + description="A message publisher and listener for native Python objects and Zarr arrays/groups." + ) parser.add_argument( - "--mode", type=str, default="publish", + "--mode", + type=str, + default="publish", choices={"publish", "listen"}, - help="The transmission mode" + help="The transmission mode", ) parser.add_argument( - "--mware", type=str, default=DEFAULT_COMMUNICATOR, + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission" + help="The middleware to use for transmission", ) return parser.parse_args() def main(args): - """Main function to initiate Notifier class and communication.""" + """ + Main function to initiate Notifier class and communication. + """ notifier = Notifier() notifier.activate_communication(Notifier.exchange_object, mode=args.mode) while True: - msg_object, = notifier.exchange_object(mware=args.mware) + (msg_object,) = notifier.exchange_object(mware=args.mware) print("Method result:", msg_object) diff --git a/examples/hello_world.py b/examples/hello_world.py index 5496e3d..6edcc7b 100755 --- a/examples/hello_world.py +++ b/examples/hello_world.py @@ -29,6 +29,7 @@ """ + import argparse from wrapyfi.connect.wrapper import MiddlewareCommunicator, DEFAULT_COMMUNICATOR @@ -36,22 +37,64 @@ class HelloWorld(MiddlewareCommunicator): - @MiddlewareCommunicator.register("NativeObject", "$mware", "HelloWorld", "/hello/my_message", - carrier="tcp", should_wait=True) + @MiddlewareCommunicator.register( + "NativeObject", + "$mware", + "HelloWorld", + "/hello/my_message", + carrier="tcp", + should_wait=True, + ) def send_message(self, arg_from_requester="", mware=None): - """Exchange messages and mirror user input.""" + """ + Exchange messages and mirror user input. + """ msg = input("Type your message: ") obj = {"message": msg, "message_from_requester": arg_from_requester} - return obj, + return (obj,) + def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--publish", dest="mode", action="store_const", const="publish", default="listen", help="Publish mode") - parser.add_argument("--listen", dest="mode", action="store_const", const="listen", default="listen", help="Listen mode (default)") - parser.add_argument("--request", dest="mode", action="store_const", const="request", default="listen", help="Request mode") - parser.add_argument("--reply", dest="mode", action="store_const", const="reply", default="listen", help="Reply mode") - parser.add_argument("--mware", type=str, default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission") + parser.add_argument( + "--publish", + dest="mode", + action="store_const", + const="publish", + default="listen", + help="Publish mode", + ) + parser.add_argument( + "--listen", + dest="mode", + action="store_const", + const="listen", + default="listen", + help="Listen mode (default)", + ) + parser.add_argument( + "--request", + dest="mode", + action="store_const", + const="request", + default="listen", + help="Request mode", + ) + parser.add_argument( + "--reply", + dest="mode", + action="store_const", + const="reply", + default="listen", + help="Reply mode", + ) + parser.add_argument( + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, + choices=MiddlewareCommunicator.get_communicators(), + help="The middleware to use for transmission", + ) return parser.parse_args() @@ -61,6 +104,9 @@ def parse_args(): hello_world.activate_communication(HelloWorld.send_message, mode=args.mode) while True: - my_message, = hello_world.send_message(arg_from_requester=f"I got this message from the script running in {args.mode} mode", mware=args.mware) + (my_message,) = hello_world.send_message( + arg_from_requester=f"I got this message from the script running in {args.mode} mode", + mware=args.mware, + ) if my_message is not None: print("Method result:", my_message) diff --git a/examples/robots/icub_head.py b/examples/robots/icub_head.py index c330bd9..d61cbb2 100755 --- a/examples/robots/icub_head.py +++ b/examples/robots/icub_head.py @@ -86,9 +86,15 @@ from wrapyfi.connect.wrapper import MiddlewareCommunicator ICUB_DEFAULT_COMMUNICATOR = os.environ.get("WRAPYFI_DEFAULT_COMMUNICATOR", "yarp") -ICUB_DEFAULT_COMMUNICATOR = os.environ.get("WRAPYFI_DEFAULT_MWARE", ICUB_DEFAULT_COMMUNICATOR) -ICUB_DEFAULT_COMMUNICATOR = os.environ.get("ICUB_DEFAULT_COMMUNICATOR", ICUB_DEFAULT_COMMUNICATOR) -ICUB_DEFAULT_COMMUNICATOR = os.environ.get("ICUB_DEFAULT_MWARE", ICUB_DEFAULT_COMMUNICATOR) +ICUB_DEFAULT_COMMUNICATOR = os.environ.get( + "WRAPYFI_DEFAULT_MWARE", ICUB_DEFAULT_COMMUNICATOR +) +ICUB_DEFAULT_COMMUNICATOR = os.environ.get( + "ICUB_DEFAULT_COMMUNICATOR", ICUB_DEFAULT_COMMUNICATOR +) +ICUB_DEFAULT_COMMUNICATOR = os.environ.get( + "ICUB_DEFAULT_MWARE", ICUB_DEFAULT_COMMUNICATOR +) EMOTION_LOOKUP = { "Neutral": [("LIGHTS", "neu")], @@ -101,7 +107,7 @@ "Contempt": [("raw", "L01"), ("raw", "R09"), ("raw", "ME9")], # change to array "Cunning": [("LIGHTS", "cun")], "Shy": [("LIGHTS", "shy")], - "Evil": [("LIGHTS", "evi")] + "Evil": [("LIGHTS", "evi")], } @@ -129,7 +135,9 @@ def cartesian_to_spherical(xyz=None, x=None, y=None, z=None, expand_return=None) ptr = np.zeros((3,)) xy = xyz[0] ** 2 + xyz[1] ** 2 ptr[0] = np.arctan2(xyz[1], xyz[0]) - ptr[1] = np.arctan2(xyz[2], np.sqrt(xy)) # for elevation angle defined from XY-plane up + ptr[1] = np.arctan2( + xyz[2], np.sqrt(xy) + ) # for elevation angle defined from XY-plane up # ptr[1] = np.arctan2(np.sqrt(xy), xyz[2]) # for elevation angle defined from Z-axis down ptr[2] = np.sqrt(xy + xyz[2] ** 2) return ptr if not expand_return else {"p": ptr[0], "t": ptr[1], "r": ptr[2]} @@ -145,6 +153,7 @@ def mode_smoothing_filter(time_series, default, window_length=6, min_count=None) :param min_count: int: Minimum number of values in the window to apply the smoothing filter """ import scipy.stats + if min_count is None: min_count = window_length // 2 mode = scipy.stats.mode(time_series[-window_length:]) @@ -171,15 +180,26 @@ class ICub(MiddlewareCommunicator, yarp.RFModule): FACIAL_EXPRESSIONS_QUEUE_SIZE = 7 FACIAL_EXPRESSION_SMOOTHING_WINDOW = 6 - def __init__(self, simulation=False, headless=False, get_cam_feed=True, - img_width=CAP_PROP_FRAME_WIDTH, img_height=CAP_PROP_FRAME_HEIGHT, - control_head=True, set_head_coordinates=True, head_coordinates_port=HEAD_COORDINATES_PORT, - control_eyes=True, set_eye_coordinates=True, eye_coordinates_port=EYE_COORDINATES_PORT, - ikingaze=False, - gaze_plane_coordinates_port=GAZE_PLANE_COORDINATES_PORT, - control_expressions=False, - set_facial_expressions=True, facial_expressions_port=FACIAL_EXPRESSIONS_PORT, - mware=MWARE): + def __init__( + self, + simulation=False, + headless=False, + get_cam_feed=True, + img_width=CAP_PROP_FRAME_WIDTH, + img_height=CAP_PROP_FRAME_HEIGHT, + control_head=True, + set_head_coordinates=True, + head_coordinates_port=HEAD_COORDINATES_PORT, + control_eyes=True, + set_eye_coordinates=True, + eye_coordinates_port=EYE_COORDINATES_PORT, + ikingaze=False, + gaze_plane_coordinates_port=GAZE_PLANE_COORDINATES_PORT, + control_expressions=False, + set_facial_expressions=True, + facial_expressions_port=FACIAL_EXPRESSIONS_PORT, + mware=MWARE, + ): """ Initialize the ICub head controller, facial expression transmitter and camera viewer. @@ -222,15 +242,19 @@ def __init__(self, simulation=False, headless=False, get_cam_feed=True, if simulation: props.put("remote", "/icubSim/head") - self.cam_props = {"cam_world_port": "/icubSim/cam", - "cam_left_port": "/icubSim/cam/left", - "cam_right_port": "/icubSim/cam/right"} + self.cam_props = { + "cam_world_port": "/icubSim/cam", + "cam_left_port": "/icubSim/cam/left", + "cam_right_port": "/icubSim/cam/right", + } emotion_cmd = f"yarp rpc /icubSim/face/emotions/in" else: props.put("remote", "/icub/head") - self.cam_props = {"cam_world_port": "/icub/cam/left", - "cam_left_port": "/icub/cam/left", - "cam_right_port": "/icub/cam/right"} + self.cam_props = { + "cam_world_port": "/icub/cam/left", + "cam_left_port": "/icub/cam/left", + "cam_right_port": "/icub/cam/right", + } emotion_cmd = f"yarp rpc /icub/face/emotions/in" if img_width is not None: @@ -248,10 +272,15 @@ def __init__(self, simulation=False, headless=False, get_cam_feed=True, # control emotional expressions using RPC self.client = pexpect.spawn(emotion_cmd) else: - logging.error("pexpect must be installed to control the emotion interface") + logging.error( + "pexpect must be installed to control the emotion interface" + ) self.activate_communication(self.update_facial_expressions, "disable") - self.last_expression = ["", ""] # (emotion part on the robot's face , emotional expression category) + self.last_expression = [ + "", + "", + ] # (emotion part on the robot's face , emotional expression category) self.expressions_queue = deque(maxlen=self.FACIAL_EXPRESSIONS_QUEUE_SIZE) else: self.activate_communication(self.update_facial_expressions, "disable") @@ -352,9 +381,20 @@ def build(self): Updates the default method arguments according to constructor arguments. This method is called by the module constructor. It is not necessary to call it manually. """ - ICub.acquire_head_coordinates.__defaults__ = (self.HEAD_COORDINATES_PORT, None, self.MWARE) - ICub.acquire_eye_coordinates.__defaults__ = (self.EYE_COORDINATES_PORT, None, self.MWARE) - ICub.receive_gaze_plane_coordinates.__defaults__ = (self.GAZE_PLANE_COORDINATES_PORT, self.MWARE) + ICub.acquire_head_coordinates.__defaults__ = ( + self.HEAD_COORDINATES_PORT, + None, + self.MWARE, + ) + ICub.acquire_eye_coordinates.__defaults__ = ( + self.EYE_COORDINATES_PORT, + None, + self.MWARE, + ) + ICub.receive_gaze_plane_coordinates.__defaults__ = ( + self.GAZE_PLANE_COORDINATES_PORT, + self.MWARE, + ) ICub.wait_for_gaze.__defaults__ = (True, self.MWARE) ICub.reset_gaze.__defaults__ = (self.MWARE,) ICub.update_head_gaze_speed.__defaults__ = (10.0, 10.0, 20.0, 0.8, self.MWARE) @@ -362,16 +402,37 @@ def build(self): ICub.update_eye_gaze_speed.__defaults__ = (10.0, 10.0, 20.0, 0.5, self.MWARE) ICub.control_eye_gaze.__defaults__ = (0.0, 0.0, 0.0, self.MWARE) ICub._control_head_eye_gaze.__defaults__ = (self.MWARE,) - ICub.control_gaze_at_plane.__defaults__ = (0, 0, 0.3, 0.3, True, True, self.MWARE) - ICub.acquire_facial_expressions.__defaults__ = (self.FACIAL_EXPRESSIONS_PORT, None, self.MWARE) + ICub.control_gaze_at_plane.__defaults__ = ( + 0, + 0, + 0.3, + 0.3, + True, + True, + self.MWARE, + ) + ICub.acquire_facial_expressions.__defaults__ = ( + self.FACIAL_EXPRESSIONS_PORT, + None, + self.MWARE, + ) ICub.update_facial_expressions.__defaults__ = (None, False, "mode", self.MWARE) - ICub.receive_images.__defaults__ = (self.CAP_PROP_FRAME_WIDTH, self.CAP_PROP_FRAME_HEIGHT, True) - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "$head_coordinates_port", - should_wait=False) - def acquire_head_coordinates(self, head_coordinates_port=HEAD_COORDINATES_PORT, cv2_key=None, - _mware=MWARE, **kwargs): + ICub.receive_images.__defaults__ = ( + self.CAP_PROP_FRAME_WIDTH, + self.CAP_PROP_FRAME_HEIGHT, + True, + ) + + @MiddlewareCommunicator.register( + "NativeObject", "$_mware", "ICub", "$head_coordinates_port", should_wait=False + ) + def acquire_head_coordinates( + self, + head_coordinates_port=HEAD_COORDINATES_PORT, + cv2_key=None, + _mware=MWARE, + **kwargs, + ): """ Acquire head coordinates for controlling the iCub. @@ -382,7 +443,7 @@ def acquire_head_coordinates(self, head_coordinates_port=HEAD_COORDINATES_PORT, if cv2_key is None: logging.error("controlling orientation in headless mode not yet supported") - return None, + return (None,) else: if cv2_key == 27: # Esc key to exit exit(0) @@ -413,20 +474,29 @@ def acquire_head_coordinates(self, head_coordinates_port=HEAD_COORDINATES_PORT, logging.info("resetting the orientation") else: logging.info(cv2_key) # else print its value - return None, + return (None,) - return {"topic": head_coordinates_port.split("/")[-1], + return ( + { + "topic": head_coordinates_port.split("/")[-1], "timestamp": time.time(), "pitch": self._curr_head[0], "roll": self._curr_head[1], "yaw": self._curr_head[2], - "order": "zyx"}, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "$eye_coordinates_port", - should_wait=False) - def acquire_eye_coordinates(self, eye_coordinates_port=EYE_COORDINATES_PORT, cv2_key=None, - _mware=MWARE, **kwargs): + "order": "zyx", + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", "$_mware", "ICub", "$eye_coordinates_port", should_wait=False + ) + def acquire_eye_coordinates( + self, + eye_coordinates_port=EYE_COORDINATES_PORT, + cv2_key=None, + _mware=MWARE, + **kwargs, + ): """ Acquire eye coordinates for controlling the iCub. @@ -437,7 +507,7 @@ def acquire_eye_coordinates(self, eye_coordinates_port=EYE_COORDINATES_PORT, cv2 if cv2_key is None: logging.error("controlling orientation in headless mode not yet supported") - return None, + return (None,) else: if cv2_key == 27: # Esc key to exit exit(0) @@ -462,30 +532,46 @@ def acquire_eye_coordinates(self, eye_coordinates_port=EYE_COORDINATES_PORT, cv2 logging.info("resetting the orientation") else: logging.info(cv2_key) # else print its value - return None, + return (None,) - return {"topic": eye_coordinates_port.split("/")[-1], + return ( + { + "topic": eye_coordinates_port.split("/")[-1], "timestamp": time.time(), "pitch": self._curr_eyes[0], "yaw": self._curr_eyes[1], - "vergence": self._curr_eyes[2]}, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "$gaze_plane_coordinates_port", - should_wait=False) - def receive_gaze_plane_coordinates(self, gaze_plane_coordinates_port=GAZE_PLANE_COORDINATES_PORT, - _mware=MWARE, **kwargs): + "vergence": self._curr_eyes[2], + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ICub", + "$gaze_plane_coordinates_port", + should_wait=False, + ) + def receive_gaze_plane_coordinates( + self, + gaze_plane_coordinates_port=GAZE_PLANE_COORDINATES_PORT, + _mware=MWARE, + **kwargs, + ): """ Receive gaze plane (normalized x,y) coordinates for controlling the iCub. :param gaze_plane_coordinates_port: str: Port to receive gaze plane coordinates :return: dict: Gaze plane coordinates """ - return None, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "/icub_controller/logs/wait_for_gaze", - should_wait=False) + return (None,) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ICub", + "/icub_controller/logs/wait_for_gaze", + should_wait=False, + ) def wait_for_gaze(self, reset=True, _mware=MWARE): """ Wait for the gaze actuation to complete. @@ -506,13 +592,21 @@ def wait_for_gaze(self, reset=True, _mware=MWARE): self._ipos.positionMove(self._encs.data()) while not self._ipos.checkMotionDone(): pass - return {"topic": "logging_wait_for_gaze", + return ( + { + "topic": "logging_wait_for_gaze", "timestamp": time.time(), - "command": f"waiting for gaze completed with reset={reset}"}, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "/icub_controller/logs/reset_gaze", - should_wait=False) + "command": f"waiting for gaze completed with reset={reset}", + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ICub", + "/icub_controller/logs/reset_gaze", + should_wait=False, + ) def reset_gaze(self, _mware=MWARE): """ Reset the eyes and head to their original position. @@ -521,14 +615,24 @@ def reset_gaze(self, _mware=MWARE): :return: dict: Gaze reset log for a given time step """ self.wait_for_gaze(reset=True) - return {"topic": "logging_reset_gaze", + return ( + { + "topic": "logging_reset_gaze", "timestamp": time.time(), - "command": f"reset gaze"}, - - @MiddlewareCommunicator.register("NativeObject", "$mware", - "ICub", "/icub_controller/logs/head_speed", - should_wait=False) - def update_head_gaze_speed(self, pitch=10.0, roll=10.0, yaw=20.0, head=0.8, _mware=MWARE, **kwargs): + "command": f"reset gaze", + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", + "$mware", + "ICub", + "/icub_controller/logs/head_speed", + should_wait=False, + ) + def update_head_gaze_speed( + self, pitch=10.0, roll=10.0, yaw=20.0, head=0.8, _mware=MWARE, **kwargs + ): """ Control the iCub head speed. @@ -542,22 +646,36 @@ def update_head_gaze_speed(self, pitch=10.0, roll=10.0, yaw=20.0, head=0.8, _mwa if self.ikingaze: self._igaze.setNeckTrajTime(head) - return {"topic": "logging_head_speed", + return ( + { + "topic": "logging_head_speed", "timestamp": time.time(), - "command": f"head speed set to {head}"}, + "command": f"head speed set to {head}", + }, + ) else: self._ipos.setRefSpeed(0, pitch) self._ipos.setRefSpeed(1, roll) self._ipos.setRefSpeed(2, yaw) - return {"topic": "logging_head_speed", + return ( + { + "topic": "logging_head_speed", "timestamp": time.time(), - "command": f"head speed set to {pitch, roll, yaw} (pitch, roll, yaw)"}, - - @MiddlewareCommunicator.register("NativeObject", "$mware", - "ICub", "/icub_controller/logs/eye_speed", - should_wait=False) - def update_eye_gaze_speed(self, pitch=10.0, yaw=10.0, vergence=20.0, eye=0.5, _mware=MWARE, **kwargs): + "command": f"head speed set to {pitch, roll, yaw} (pitch, roll, yaw)", + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", + "$mware", + "ICub", + "/icub_controller/logs/eye_speed", + should_wait=False, + ) + def update_eye_gaze_speed( + self, pitch=10.0, yaw=10.0, vergence=20.0, eye=0.5, _mware=MWARE, **kwargs + ): """ Control the iCub eye speed. @@ -571,22 +689,36 @@ def update_eye_gaze_speed(self, pitch=10.0, yaw=10.0, vergence=20.0, eye=0.5, _m if self.ikingaze: self._igaze.setEyesTrajTime(eye) - return {"topic": "logging_eye_speed", + return ( + { + "topic": "logging_eye_speed", "timestamp": time.time(), - "command": f"eye speed set to {eye}"}, + "command": f"eye speed set to {eye}", + }, + ) else: self._ipos.setRefSpeed(3, pitch) self._ipos.setRefSpeed(4, yaw) self._ipos.setRefSpeed(5, vergence) - return {"topic": "logging_eye_speed", + return ( + { + "topic": "logging_eye_speed", "timestamp": time.time(), - "command": f"eye speed set to {pitch, yaw, vergence} (pitch, yaw, vergence)"}, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "/icub_controller/logs/head_orientation_coordinates", - should_wait=False) - def control_head_gaze(self, pitch=0.0, roll=0.0, yaw=0.0, order="xyz", _mware=MWARE, **kwargs): + "command": f"eye speed set to {pitch, yaw, vergence} (pitch, yaw, vergence)", + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ICub", + "/icub_controller/logs/head_orientation_coordinates", + should_wait=False, + ) + def control_head_gaze( + self, pitch=0.0, roll=0.0, yaw=0.0, order="xyz", _mware=MWARE, **kwargs + ): """ Control the iCub head relative to previous coordinates following the roll,pitch,yaw convention (order=xyz) (initialized at 0 looking straight ahead). @@ -599,8 +731,10 @@ def control_head_gaze(self, pitch=0.0, roll=0.0, yaw=0.0, order="xyz", _mware=MW :return: dict: Head orientation coordinates log for a given time step """ if order != "xyz": - logging.error("only accepts ratation angles following the order='xyz' convention") - return None, + logging.error( + "only accepts ratation angles following the order='xyz' convention" + ) + return (None,) # wait for the action to complete # self.wait_for_gaze(reset=False) @@ -615,14 +749,24 @@ def control_head_gaze(self, pitch=0.0, roll=0.0, yaw=0.0, order="xyz", _mware=MW # self._ipos.positionMove(self.init_pos_head.data()) self._curr_head = list((pitch, roll, yaw)) - return {"topic": "logging_head_coordinates", + return ( + { + "topic": "logging_head_coordinates", "timestamp": time.time(), - "command": f"head orientation set to {self._curr_head} (pitch, roll, yaw)"}, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "/icub_controller/logs/eye_orientation_coordinates", - should_wait=False) - def control_eye_gaze(self, pitch=0.0, yaw=0.0, vergence=0.0, _mware=MWARE, **kwargs): + "command": f"head orientation set to {self._curr_head} (pitch, roll, yaw)", + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ICub", + "/icub_controller/logs/eye_orientation_coordinates", + should_wait=False, + ) + def control_eye_gaze( + self, pitch=0.0, yaw=0.0, vergence=0.0, _mware=MWARE, **kwargs + ): """ Control the iCub eyes relative to previous coordinates (initialized at 0 looking straight ahead). @@ -641,19 +785,28 @@ def control_eye_gaze(self, pitch=0.0, yaw=0.0, vergence=0.0, _mware=MWARE, **kwa # eye control self.init_pos_eyes.set(3, self.init_pos_eyes.get(3) + pitch) # eye tilt self.init_pos_eyes.set(4, self.init_pos_eyes.get(4) + yaw) # eye pan/version - self.init_pos_eyes.set(5, self.init_pos_eyes.get( - 5) + vergence) # the vergence between the eyes (to align, set to 0) + self.init_pos_eyes.set( + 5, self.init_pos_eyes.get(5) + vergence + ) # the vergence between the eyes (to align, set to 0) # self._ipos.positionMove(self.init_pos_eyes.data()) self._curr_eyes = list((pitch, yaw, vergence)) - return {"topic": "logging_eye_coordinates", + return ( + { + "topic": "logging_eye_coordinates", "timestamp": time.time(), - "command": f"eye orientation set to {self._curr_eyes} (pitch, yaw, vergence)"}, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "/icub_controller/logs/head_eye_orientation_coordinates", - should_wait=False) + "command": f"eye orientation set to {self._curr_eyes} (pitch, yaw, vergence)", + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ICub", + "/icub_controller/logs/head_eye_orientation_coordinates", + should_wait=False, + ) def _control_head_eye_gaze(self, _mware=MWARE, **kwargs): """ Issue the movement command @@ -671,18 +824,37 @@ def _control_head_eye_gaze(self, _mware=MWARE, **kwargs): self.init_pos.set(2, self.init_pos_head.get(2)) # pan/yaw self.init_pos.set(3, self.init_pos_eyes.get(3)) # eye tilt self.init_pos.set(4, self.init_pos_eyes.get(4)) # eye pan/version - self.init_pos.set(5, self.init_pos_eyes.get(5)) # the vergence between the eyes (to align, set to 0) + self.init_pos.set( + 5, self.init_pos_eyes.get(5) + ) # the vergence between the eyes (to align, set to 0) self._ipos.positionMove(self.init_pos.data()) - return {"topic": "logging_head_eye_coordinates", + return ( + { + "topic": "logging_head_eye_coordinates", "timestamp": time.time(), - "command": f"head orientation set to {self._curr_head} (pitch, roll, yaw) and eye orientation to {self._curr_eyes} (pitch, yaw, vergence)"}, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "/icub_controller/logs/gaze_plane_coordinates", - should_wait=False) - def control_gaze_at_plane(self, x=0.0, y=0.0, limit_x=0.3, limit_y=0.3, control_eyes=True, control_head=True, - _mware=MWARE, **kwargs): + "command": f"head orientation set to {self._curr_head} (pitch, roll, yaw) and eye orientation to {self._curr_eyes} (pitch, yaw, vergence)", + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ICub", + "/icub_controller/logs/gaze_plane_coordinates", + should_wait=False, + ) + def control_gaze_at_plane( + self, + x=0.0, + y=0.0, + limit_x=0.3, + limit_y=0.3, + control_eyes=True, + control_head=True, + _mware=MWARE, + **kwargs, + ): """ Gaze at specific point in a normalized plane in front of the iCub. @@ -704,8 +876,10 @@ def control_gaze_at_plane(self, x=0.0, y=0.0, limit_x=0.3, limit_y=0.3, control_ if control_eyes and control_head: if not self.ikingaze: - logging.error("set ikingaze=True in order to move eyes and head simultaneously") - return None, + logging.error( + "set ikingaze=True in order to move eyes and head simultaneously" + ) + return (None,) self.init_pos_ikin = yarp.Vector(3, self._gaze_encs.data()) self.init_pos_ikin.set(0, ptr_degrees[0]) self.init_pos_ikin.set(1, ptr_degrees[1]) @@ -715,23 +889,32 @@ def control_gaze_at_plane(self, x=0.0, y=0.0, limit_x=0.3, limit_y=0.3, control_ elif control_head: if self.ikingaze: logging.error("set ikingaze=False in order to move head only") - return None, + return (None,) self.control_head_gaze(pitch=ptr_degrees[1], roll=0, yaw=ptr_degrees[0]) elif control_eyes: if self.ikingaze: logging.error("set ikingaze=False in order to move eyes only") - return None, + return (None,) self.control_eye_gaze(pitch=ptr_degrees[1], yaw=ptr_degrees[0], vergence=0) - return {"topic": "logging_gaze_plane_coordinates", + return ( + { + "topic": "logging_gaze_plane_coordinates", "timestamp": time.time(), - "command": f"moving gaze toward {ptr_degrees} with head={control_head} and eyes={control_eyes}"}, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "$facial_expressions_port", - should_wait=False) - def acquire_facial_expressions(self, facial_expressions_port=FACIAL_EXPRESSIONS_PORT, cv2_key=None, - _mware=MWARE, **kwargs): + "command": f"moving gaze toward {ptr_degrees} with head={control_head} and eyes={control_eyes}", + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", "$_mware", "ICub", "$facial_expressions_port", should_wait=False + ) + def acquire_facial_expressions( + self, + facial_expressions_port=FACIAL_EXPRESSIONS_PORT, + cv2_key=None, + _mware=MWARE, + **kwargs, + ): """ Acquire facial expressions from the iCub. @@ -742,7 +925,7 @@ def acquire_facial_expressions(self, facial_expressions_port=FACIAL_EXPRESSIONS_ emotion = None if cv2_key is None: logging.error("controlling expressions in headless mode not yet supported") - return None, + return (None,) else: if cv2_key == 27: # Esc key to exit exit(0) @@ -780,15 +963,25 @@ def acquire_facial_expressions(self, facial_expressions_port=FACIAL_EXPRESSIONS_ logging.info("expressing shyness") else: logging.info(cv2_key) # else print its value - return None, - return {"topic": facial_expressions_port.split("/")[-1], + return (None,) + return ( + { + "topic": facial_expressions_port.split("/")[-1], "timestamp": time.time(), - "emotion_category": emotion}, - - @MiddlewareCommunicator.register("NativeObject", "$_mware", - "ICub", "/icub_controller/logs/facial_expressions", - should_wait=False) - def update_facial_expressions(self, expression, part=False, smoothing="mode", _mware=MWARE, **kwargs): + "emotion_category": emotion, + }, + ) + + @MiddlewareCommunicator.register( + "NativeObject", + "$_mware", + "ICub", + "/icub_controller/logs/facial_expressions", + should_wait=False, + ) + def update_facial_expressions( + self, expression, part=False, smoothing="mode", _mware=MWARE, **kwargs + ): """ Control facial expressions of the iCub. @@ -801,27 +994,34 @@ def update_facial_expressions(self, expression, part=False, smoothing="mode", _m :return: Emotion log for a given time step """ if expression is None: - return None, + return (None,) if isinstance(expression, (list, tuple)): expression = expression[-1] if smoothing == "mode": self.expressions_queue.append(expression) - transmitted_expression = mode_smoothing_filter(list(self.expressions_queue), default="neu", - window_length=self.FACIAL_EXPRESSION_SMOOTHING_WINDOW) + transmitted_expression = mode_smoothing_filter( + list(self.expressions_queue), + default="neu", + window_length=self.FACIAL_EXPRESSION_SMOOTHING_WINDOW, + ) else: transmitted_expression = expression - expressions_lookup = EMOTION_LOOKUP.get(transmitted_expression, transmitted_expression) + expressions_lookup = EMOTION_LOOKUP.get( + transmitted_expression, transmitted_expression + ) if isinstance(expressions_lookup, str): expressions_lookup = [(part if part else "LIGHTS", expressions_lookup)] - if self.last_expression[0] == (part if part else "LIGHTS") and self.last_expression[ - 1] == transmitted_expression: + if ( + self.last_expression[0] == (part if part else "LIGHTS") + and self.last_expression[1] == transmitted_expression + ): expressions_lookup = [] - for (part_lookup, expression_lookup) in expressions_lookup: + for part_lookup, expression_lookup in expressions_lookup: if part_lookup == "LIGHTS": self.client.sendline(f"set leb {expression_lookup}") self.client.expect(">>") @@ -838,18 +1038,50 @@ def update_facial_expressions(self, expression, part=False, smoothing="mode", _m self.last_expression[0] = part self.last_expression[1] = transmitted_expression - return {"topic": "logging_facial_expressions", + return ( + { + "topic": "logging_facial_expressions", "timestamp": time.time(), - "command": f"emotion set to {part} {expression} with smoothing={smoothing}"}, - - @MiddlewareCommunicator.register("Image", "yarp", "ICub", "$cam_world_port", - width="$img_width", height="$img_height", rgb="$_rgb") - @MiddlewareCommunicator.register("Image", "yarp", "ICub", "$cam_left_port", - width="$img_width", height="$img_height", rgb="$_rgb") - @MiddlewareCommunicator.register("Image", "yarp", "ICub", "$cam_right_port", - width="$img_width", height="$img_height", rgb="$_rgb") - def receive_images(self, cam_world_port, cam_left_port, cam_right_port, - img_width=CAP_PROP_FRAME_WIDTH, img_height=CAP_PROP_FRAME_HEIGHT, _rgb=True): + "command": f"emotion set to {part} {expression} with smoothing={smoothing}", + }, + ) + + @MiddlewareCommunicator.register( + "Image", + "yarp", + "ICub", + "$cam_world_port", + width="$img_width", + height="$img_height", + rgb="$_rgb", + ) + @MiddlewareCommunicator.register( + "Image", + "yarp", + "ICub", + "$cam_left_port", + width="$img_width", + height="$img_height", + rgb="$_rgb", + ) + @MiddlewareCommunicator.register( + "Image", + "yarp", + "ICub", + "$cam_right_port", + width="$img_width", + height="$img_height", + rgb="$_rgb", + ) + def receive_images( + self, + cam_world_port, + cam_left_port, + cam_right_port, + img_width=CAP_PROP_FRAME_WIDTH, + img_height=CAP_PROP_FRAME_HEIGHT, + _rgb=True, + ): """ Receive images from the iCub. @@ -883,75 +1115,128 @@ def updateModule(self): left_cam = cv2.cvtColor(left_cam, cv2.COLOR_BGR2RGB) right_cam = cv2.cvtColor(right_cam, cv2.COLOR_BGR2RGB) if not self.headless: - cv2.imshow("ICubCam", np.concatenate((left_cam, external_cam, right_cam), axis=1)) + cv2.imshow( + "ICubCam", np.concatenate((left_cam, external_cam, right_cam), axis=1) + ) k = cv2.waitKey(30) else: k = None - switch_emotion, = self.acquire_facial_expressions(facial_expressions_port=self.FACIAL_EXPRESSIONS_PORT, - cv2_key=k, _mware=self.MWARE) + (switch_emotion,) = self.acquire_facial_expressions( + facial_expressions_port=self.FACIAL_EXPRESSIONS_PORT, + cv2_key=k, + _mware=self.MWARE, + ) if switch_emotion is not None and isinstance(switch_emotion, dict): - self.update_facial_expressions(switch_emotion.get("emotion_category", None), - part=switch_emotion.get("part", False), _mware=self.MWARE) + self.update_facial_expressions( + switch_emotion.get("emotion_category", None), + part=switch_emotion.get("part", False), + _mware=self.MWARE, + ) # move robot head - move_robot_head, = self.acquire_head_coordinates(head_coordinates_port=self.HEAD_COORDINATES_PORT, - cv2_key=k, _mware=self.MWARE) + (move_robot_head,) = self.acquire_head_coordinates( + head_coordinates_port=self.HEAD_COORDINATES_PORT, + cv2_key=k, + _mware=self.MWARE, + ) if move_robot_head is not None and isinstance(move_robot_head, dict): robot_head_speed = move_robot_head.get("speed", False) if robot_head_speed and isinstance(robot_head_speed, dict): - self.update_head_gaze_speed(pitch=robot_head_speed.get("pitch", 10.0), - roll=robot_head_speed.get("roll", 10.0), - yaw=robot_head_speed.get("yaw", 20.0), _mware=self.MWARE) + self.update_head_gaze_speed( + pitch=robot_head_speed.get("pitch", 10.0), + roll=robot_head_speed.get("roll", 10.0), + yaw=robot_head_speed.get("yaw", 20.0), + _mware=self.MWARE, + ) if move_robot_head.get("reset_gaze", False): self.reset_gaze() - self.control_head_gaze(pitch=move_robot_head.get("pitch", 0.0), - roll=move_robot_head.get("roll", 0.0), - yaw=move_robot_head.get("yaw", 0.0), _mware=self.MWARE) + self.control_head_gaze( + pitch=move_robot_head.get("pitch", 0.0), + roll=move_robot_head.get("roll", 0.0), + yaw=move_robot_head.get("yaw", 0.0), + _mware=self.MWARE, + ) # move robot eyes - move_robot_eyes, = self.acquire_eye_coordinates(eye_coordinates_port=self.EYE_COORDINATES_PORT, - cv2_key=k, _mware=self.MWARE) + (move_robot_eyes,) = self.acquire_eye_coordinates( + eye_coordinates_port=self.EYE_COORDINATES_PORT, cv2_key=k, _mware=self.MWARE + ) if move_robot_eyes is not None and isinstance(move_robot_eyes, dict): robot_eye_speed = move_robot_eyes.get("speed", False) if robot_eye_speed and isinstance(robot_eye_speed, dict): - self.update_eye_gaze_speed(pitch=robot_eye_speed.get("pitch", 10.0), - yaw=robot_eye_speed.get("yaw", 10.0), - vergence=robot_eye_speed.get("vergence", 20.0), _mware=self.MWARE) + self.update_eye_gaze_speed( + pitch=robot_eye_speed.get("pitch", 10.0), + yaw=robot_eye_speed.get("yaw", 10.0), + vergence=robot_eye_speed.get("vergence", 20.0), + _mware=self.MWARE, + ) if move_robot_eyes.get("reset_gaze", False): self.reset_gaze() - self.control_eye_gaze(pitch=move_robot_eyes.get("pitch", 0.0), - yaw=move_robot_eyes.get("yaw", 0.0), - vergence=move_robot_eyes.get("vergence", 0.0), _mware=self.MWARE) + self.control_eye_gaze( + pitch=move_robot_eyes.get("pitch", 0.0), + yaw=move_robot_eyes.get("yaw", 0.0), + vergence=move_robot_eyes.get("vergence", 0.0), + _mware=self.MWARE, + ) if move_robot_head is not None or move_robot_eyes is not None: self._control_head_eye_gaze() - move_robot, = self.receive_gaze_plane_coordinates(gaze_plane_coordinates_port=self.GAZE_PLANE_COORDINATES_PORT, - _mware=self.MWARE) + (move_robot,) = self.receive_gaze_plane_coordinates( + gaze_plane_coordinates_port=self.GAZE_PLANE_COORDINATES_PORT, + _mware=self.MWARE, + ) if move_robot is not None and isinstance(move_robot, dict): robot_eye_speed = move_robot.get("eye_speed", False) if robot_eye_speed and isinstance(robot_eye_speed, dict): - self.update_eye_gaze_speed(**{"pitch": robot_eye_speed.get("pitch", 10.0), - "yaw": robot_eye_speed.get("yaw", 10.0), - "vergence": robot_eye_speed.get("vergence", 20.0), "_mware": self.MWARE} - if not self.ikingaze else {"eye": robot_eye_speed.get("eye", 0.5), "_mware": self.MWARE}) + self.update_eye_gaze_speed( + **( + { + "pitch": robot_eye_speed.get("pitch", 10.0), + "yaw": robot_eye_speed.get("yaw", 10.0), + "vergence": robot_eye_speed.get("vergence", 20.0), + "_mware": self.MWARE, + } + if not self.ikingaze + else { + "eye": robot_eye_speed.get("eye", 0.5), + "_mware": self.MWARE, + } + ) + ) robot_head_speed = move_robot.get("head_speed", False) if robot_head_speed and isinstance(robot_head_speed, dict): - self.update_head_gaze_speed(**{"pitch": robot_head_speed.get("pitch", 10.0), - "roll": robot_head_speed.get("roll", 10.0), - "yaw": robot_head_speed.get("yaw", 20.0), "_mware": self.MWARE} - if not self.ikingaze else {"head": robot_head_speed.get("head", 0.8), "_mware": self.MWARE}) + self.update_head_gaze_speed( + **( + { + "pitch": robot_head_speed.get("pitch", 10.0), + "roll": robot_head_speed.get("roll", 10.0), + "yaw": robot_head_speed.get("yaw", 20.0), + "_mware": self.MWARE, + } + if not self.ikingaze + else { + "head": robot_head_speed.get("head", 0.8), + "_mware": self.MWARE, + } + ) + ) if move_robot.get("reset_gaze", False): self.reset_gaze() - self.control_gaze_at_plane(x=move_robot.get("x", 0.0), y=move_robot.get("y", 0.0), - limit_x=move_robot.get("limit_x", 0.3), - limit_y=move_robot.get("limit_y", 0.3), - control_head=move_robot.get("control_head", - False if not self.ikingaze else True), - control_eyes=move_robot.get("control_eyes", True), _mware=self.MWARE), + self.control_gaze_at_plane( + x=move_robot.get("x", 0.0), + y=move_robot.get("y", 0.0), + limit_x=move_robot.get("limit_x", 0.3), + limit_y=move_robot.get("limit_y", 0.3), + control_head=move_robot.get( + "control_head", False if not self.ikingaze else True + ), + control_eyes=move_robot.get("control_eyes", True), + _mware=self.MWARE, + ), return True @@ -961,39 +1246,77 @@ def parse_args(): parser.add_argument("--simulation", action="store_true", help="Run in simulation") parser.add_argument("--headless", action="store_true", help="Disable CV2 GUI") parser.add_argument("--ikingaze", action="store_true", help="Enable iKinGazeCtrl") - parser.add_argument("--get_cam_feed", action="store_true", help="Get the camera feeds from the robot") + parser.add_argument( + "--get_cam_feed", + action="store_true", + help="Get the camera feeds from the robot", + ) parser.add_argument("--control_head", action="store_true", help="Control the head") - parser.add_argument("--set_head_coordinates", action="store_true", - help="Publish head coordinates set using keyboard commands") - parser.add_argument("--head_coordinates_port", type=str, default="", - help="The port (topic) name used for receiving and transmitting head orientation " - "Setting the port name without --set_head_coordinates will only receive the coordinates") + parser.add_argument( + "--set_head_coordinates", + action="store_true", + help="Publish head coordinates set using keyboard commands", + ) + parser.add_argument( + "--head_coordinates_port", + type=str, + default="", + help="The port (topic) name used for receiving and transmitting head orientation " + "Setting the port name without --set_head_coordinates will only receive the coordinates", + ) parser.add_argument("--control_eyes", action="store_true", help="Control the eyes") - parser.add_argument("--set_eye_coordinates", action="store_true", - help="Publish eye coordinates set using keyboard commands") - parser.add_argument("--eye_coordinates_port", type=str, default="", - help="The port (topic) name used for receiving and transmitting eye orientation " - "Setting the port name without --set_eye_coordinates will only receive the coordinates") - parser.add_argument("--gaze_plane_coordinates_port", type=str, default="", - help="The port (topic) name used for receiving plane coordinates in 2D for robot to look at") - parser.add_argument("--control_expressions", action="store_true", help="Control the facial expressions") - parser.add_argument("--set_facial_expressions", action="store_true", - help="Publish facial expressions set using keyboard commands") - parser.add_argument("--facial_expressions_port", type=str, default="", - help="The port (topic) name used for receiving and transmitting facial expressions. " - "Setting the port name without --set_facial_expressions will only receive the facial expressions") - parser.add_argument("--mware", type=str, default=ICUB_DEFAULT_COMMUNICATOR, - help="The middleware used for communication. " - "This can be overriden by providing either of the following environment variables " - "{WRAPYFI_DEFAULT_COMMUNICATOR, WRAPYFI_DEFAULT_MWARE, " - "ICUB_DEFAULT_COMMUNICATOR, ICUB_DEFAULT_MWARE}. Defaults to 'yarp'", - choices=MiddlewareCommunicator.get_communicators()) + parser.add_argument( + "--set_eye_coordinates", + action="store_true", + help="Publish eye coordinates set using keyboard commands", + ) + parser.add_argument( + "--eye_coordinates_port", + type=str, + default="", + help="The port (topic) name used for receiving and transmitting eye orientation " + "Setting the port name without --set_eye_coordinates will only receive the coordinates", + ) + parser.add_argument( + "--gaze_plane_coordinates_port", + type=str, + default="", + help="The port (topic) name used for receiving plane coordinates in 2D for robot to look at", + ) + parser.add_argument( + "--control_expressions", + action="store_true", + help="Control the facial expressions", + ) + parser.add_argument( + "--set_facial_expressions", + action="store_true", + help="Publish facial expressions set using keyboard commands", + ) + parser.add_argument( + "--facial_expressions_port", + type=str, + default="", + help="The port (topic) name used for receiving and transmitting facial expressions. " + "Setting the port name without --set_facial_expressions will only receive the facial expressions", + ) + parser.add_argument( + "--mware", + type=str, + default=ICUB_DEFAULT_COMMUNICATOR, + help="The middleware used for communication. " + "This can be overriden by providing either of the following environment variables " + "{WRAPYFI_DEFAULT_COMMUNICATOR, WRAPYFI_DEFAULT_MWARE, " + "ICUB_DEFAULT_COMMUNICATOR, ICUB_DEFAULT_MWARE}. Defaults to 'yarp'", + choices=MiddlewareCommunicator.get_communicators(), + ) return parser.parse_args() if __name__ == "__main__": args = parse_args() - assert not (args.headless and (args.set_facial_expressions or args.set_head_eye_coordinates)), \ - "setters require a CV2 window for capturing keystrokes. Disable --set-... for running in headless mode" + assert not ( + args.headless and (args.set_facial_expressions or args.set_head_eye_coordinates) + ), "setters require a CV2 window for capturing keystrokes. Disable --set-... for running in headless mode" controller = ICub(**vars(args)) - controller.runModule() \ No newline at end of file + controller.runModule() diff --git a/examples/sensors/cam_mic.py b/examples/sensors/cam_mic.py index a56687e..34ad2ff 100755 --- a/examples/sensors/cam_mic.py +++ b/examples/sensors/cam_mic.py @@ -65,9 +65,20 @@ class CamMic(MiddlewareCommunicator): __registry__ = {} - def __init__(self, *args, stream=("audio", "video"), mic_source=0, - mic_rate=44100, mic_chunk=10000, mic_channels=1, img_source=0, - img_width=320, img_height=240, mware=None, **kwargs): + def __init__( + self, + *args, + stream=("audio", "video"), + mic_source=0, + mic_rate=44100, + mic_chunk=10000, + mic_channels=1, + img_source=0, + img_width=320, + img_height=240, + mware=None, + **kwargs, + ): super(MiddlewareCommunicator, self).__init__() self.mic_source = mic_source self.mic_rate = mic_rate @@ -83,11 +94,22 @@ def __init__(self, *args, stream=("audio", "video"), mic_source=0, self.mware = mware - @MiddlewareCommunicator.register("Image", "$mware", "CamMic", "/cam_mic/cam_feed", - carrier="", width="$img_width", height="$img_height", rgb=True, jpg=True, - queue_size=10) + @MiddlewareCommunicator.register( + "Image", + "$mware", + "CamMic", + "/cam_mic/cam_feed", + carrier="", + width="$img_width", + height="$img_height", + rgb=True, + jpg=True, + queue_size=10, + ) def collect_cam(self, img_width=320, img_height=240, mware=None): - """Collect images from the camera.""" + """ + Collect images from the camera. + """ if self.vid_cap is True: self.vid_cap = cv2.VideoCapture(self.img_source) if img_width > 0 and img_height > 0: @@ -102,25 +124,53 @@ def collect_cam(self, img_width=320, img_height=240, mware=None): print("Video frame grabbed") else: print("Video frame not grabbed") - img = np.random.randint(256, size=(img_height, img_width, 3), dtype=np.uint8) + img = np.random.randint( + 256, size=(img_height, img_width, 3), dtype=np.uint8 + ) else: print("Video capturer not opened") - img = np.random.randint(256, size=(img_height, img_width, 3), dtype=np.uint8) - return img, + img = np.random.randint( + 256, size=(img_height, img_width, 3), dtype=np.uint8 + ) + return (img,) - @MiddlewareCommunicator.register("AudioChunk", "$mware", "CamMic", "/cam_mic/audio_feed", - carrier="", rate="$mic_rate", chunk="$mic_chunk", channels="$mic_channels") - def collect_mic(self, aud=None, mic_rate=44100, mic_chunk=int(44100 / 5), mic_channels=1, mware=None): - """Collect audio from the microphone.""" + @MiddlewareCommunicator.register( + "AudioChunk", + "$mware", + "CamMic", + "/cam_mic/audio_feed", + carrier="", + rate="$mic_rate", + chunk="$mic_chunk", + channels="$mic_channels", + ) + def collect_mic( + self, + aud=None, + mic_rate=44100, + mic_chunk=int(44100 / 5), + mic_channels=1, + mware=None, + ): + """ + Collect audio from the microphone. + """ aud = aud, mic_rate - return aud, + return (aud,) def capture_cam_mic(self): - """Capture audio and video from the camera and microphone.""" + """ + Capture audio and video from the camera and microphone. + """ if self.enable_audio: # capture the audio stream from the microphone - with sd.InputStream(device=self.mic_source, channels=self.mic_channels, callback=self._mic_callback, - blocksize=self.mic_chunk, samplerate=self.mic_rate): + with sd.InputStream( + device=self.mic_source, + channels=self.mic_channels, + callback=self._mic_callback, + blocksize=self.mic_chunk, + samplerate=self.mic_rate, + ): while True: pass elif self.enable_video: @@ -128,36 +178,92 @@ def capture_cam_mic(self): self.collect_cam(mware=self.mware) def _mic_callback(self, audio, frames, time, status): - """Callback for the microphone audio stream.""" + """ + Callback for the microphone audio stream. + """ if self.enable_video: - self.collect_cam(img_width=self.img_width, img_height=self.img_height, mware=self.mware) - self.collect_mic(audio, mic_rate=self.mic_rate, mic_chunk=self.mic_chunk, mic_channels=self.mic_channels, mware=self.mware) + self.collect_cam( + img_width=self.img_width, img_height=self.img_height, mware=self.mware + ) + self.collect_mic( + audio, + mic_rate=self.mic_rate, + mic_chunk=self.mic_chunk, + mic_channels=self.mic_channels, + mware=self.mware, + ) print(audio.flatten(), audio.min(), audio.mean(), audio.max()) def __del__(self): - """Release the video capture device.""" + """ + Release the video capture device. + """ if not isinstance(self.vid_cap, bool): self.vid_cap.release() def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="A streamer and listener for audio and video streams.") - parser.add_argument("--mode", type=str, default="publish", choices={"publish", "listen"}, help="The transmission mode") - parser.add_argument("--mware", type=str, default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(), - help="The middleware to use for transmission") - parser.add_argument("--stream", nargs="+", default=["video", "audio"], - choices={"video", "audio"}, - help="The streamed sensor data") - parser.add_argument("--img_source", type=int, default=0, help="The video capture device id (int camera id)") + """ + Parse command line arguments. + """ + parser = argparse.ArgumentParser( + description="A streamer and listener for audio and video streams." + ) + parser.add_argument( + "--mode", + type=str, + default="publish", + choices={"publish", "listen"}, + help="The transmission mode", + ) + parser.add_argument( + "--mware", + type=str, + default=DEFAULT_COMMUNICATOR, + choices=MiddlewareCommunicator.get_communicators(), + help="The middleware to use for transmission", + ) + parser.add_argument( + "--stream", + nargs="+", + default=["video", "audio"], + choices={"video", "audio"}, + help="The streamed sensor data", + ) + parser.add_argument( + "--img_source", + type=int, + default=0, + help="The video capture device id (int camera id)", + ) parser.add_argument("--img_width", type=int, default=320, help="The image width") parser.add_argument("--img_height", type=int, default=240, help="The image height") - parser.add_argument("--mic_source", type=int, default=None, help="The audio capture device id (int microphone id from python3 -m sounddevice)") - parser.add_argument("--mic_rate", type=int, default=44100, help="The audio sampling rate") - parser.add_argument("--mic_channels", type=int, default=1, help="The audio channels") - parser.add_argument("--mic_chunk", type=int, default=10000, help="The transmitted audio chunk size") - parser.add_argument("--sound_device", type=int, default=0, help="The sound device to use for audio playback") - parser.add_argument("--list_sound_devices", action="store_true", help="List all available sound devices and exit") + parser.add_argument( + "--mic_source", + type=int, + default=None, + help="The audio capture device id (int microphone id from python3 -m sounddevice)", + ) + parser.add_argument( + "--mic_rate", type=int, default=44100, help="The audio sampling rate" + ) + parser.add_argument( + "--mic_channels", type=int, default=1, help="The audio channels" + ) + parser.add_argument( + "--mic_chunk", type=int, default=10000, help="The transmitted audio chunk size" + ) + parser.add_argument( + "--sound_device", + type=int, + default=0, + help="The sound device to use for audio playback", + ) + parser.add_argument( + "--list_sound_devices", + action="store_true", + help="List all available sound devices and exit", + ) return parser.parse_args() @@ -174,7 +280,9 @@ def sound_play(my_aud, blocking=True, device=0): sd.play(*my_aud, blocking=blocking, device=device) return True except sd.PortAudioError: - logging.warning("PortAudioError: No device is found or the device is already in use. Will try again in 3 seconds.") + logging.warning( + "PortAudioError: No device is found or the device is already in use. Will try again in 3 seconds." + ) return False @@ -194,11 +302,21 @@ def main(args): cam_mic.activate_communication(CamMic.collect_mic, mode="listen") while True: if "audio" in args.stream: - (aud, mic_rate), = cam_mic.collect_mic(mic_rate=args.mic_rate, mic_chunk=args.mic_chunk, mic_channels=args.mic_channels, mware=args.mware) + ((aud, mic_rate),) = cam_mic.collect_mic( + mic_rate=args.mic_rate, + mic_chunk=args.mic_chunk, + mic_channels=args.mic_channels, + mware=args.mware, + ) else: aud = mic_rate = None if "video" in args.stream: - img, = cam_mic.collect_cam(img_source=args.img_source, img_width=args.img_width, img_height=args.img_height, mware=args.mware) + (img,) = cam_mic.collect_cam( + img_source=args.img_source, + img_width=args.img_width, + img_height=args.img_height, + mware=args.mware, + ) else: img = None if img is not None: diff --git a/setup.py b/setup.py index 9ca68ba..ae26782 100755 --- a/setup.py +++ b/setup.py @@ -8,21 +8,31 @@ def check_cv2(default_python="opencv-python"): import pkg_resources from packaging import version import cv2 + if version.parse(cv2.__version__) < version.parse(REQUIRED_CV2_VERSION): UPGRADE_CV2 = True raise ImportError(f"OpenCV version must be at least {REQUIRED_CV2_VERSION}") except ImportError as e: import pkg_resources + if UPGRADE_CV2: print(e, "Will try to upgrade OpenCV") if "opencv-python" in [p.project_name for p in pkg_resources.working_set]: additional_packages = [f"opencv-python>={REQUIRED_CV2_VERSION}"] - elif "opencv-contrib-python" in [p.project_name for p in pkg_resources.working_set]: + elif "opencv-contrib-python" in [ + p.project_name for p in pkg_resources.working_set + ]: additional_packages = [f"opencv-contrib-python>={REQUIRED_CV2_VERSION}"] - elif "opencv-python-headless" in [p.project_name for p in pkg_resources.working_set]: - additional_packages = [f"opencv-python-headless>={REQUIRED_CV2_VERSION}"] + elif "opencv-python-headless" in [ + p.project_name for p in pkg_resources.working_set + ]: + additional_packages = [ + f"opencv-python-headless>={REQUIRED_CV2_VERSION}" + ] else: - raise ImportError(f"Unknown OpenCV package installed. Please upgrade manually to version >={REQUIRED_CV2_VERSION}") + raise ImportError( + f"Unknown OpenCV package installed. Please upgrade manually to version >={REQUIRED_CV2_VERSION}" + ) else: print(f"OpenCV not found. Will try to install {default_python}") additional_packages = [f"{default_python}>={REQUIRED_CV2_VERSION}"] @@ -33,26 +43,53 @@ def check_cv2(default_python="opencv-python"): setuptools.setup( - name = 'wrapyfi', - version = '0.4.39', - description = 'Wrapyfi is a wrapper for simplifying Middleware communication', - url = 'https://github.com/fabawi/wrapyfi/blob/main/', + name="wrapyfi", + version="0.4.40", + description="Wrapyfi is a wrapper for simplifying Middleware communication", + url="https://github.com/fabawi/wrapyfi/blob/main/", project_urls={ - 'Documentation': 'https://wrapyfi.readthedocs.io/en/latest/', - 'Source': 'https://github.com/fabawi/wrapyfi/', - 'Tracker': 'https://github.com/fabawi/wrapyfi/issues', + "Documentation": "https://wrapyfi.readthedocs.io/en/latest/", + "Source": "https://github.com/fabawi/wrapyfi/", + "Tracker": "https://github.com/fabawi/wrapyfi/issues", + }, + author="Fares Abawi", + author_email="f.abawi@outlook.com", + maintainer="Fares Abawi", + maintainer_email="f.abawi@outlook.com", + packages=setuptools.find_packages(), + extras_require={ + "docs": ["sphinx", "sphinx_rtd_theme", "myst_parser"], + "pyzmq": ["pyzmq>=19.0.0"], + "numpy": ["numpy>=1.19.2"], + "headless": ["wrapyfi[pyzmq]", "wrapyfi[numpy]"] + + check_cv2("opencv-python-headless"), + "all": ["wrapyfi[pyzmq]", "wrapyfi[numpy]"] + + check_cv2("opencv-contrib-python"), }, - author = 'Fares Abawi', - author_email = 'f.abawi@outlook.com', - maintainer = 'Fares Abawi', - maintainer_email = 'f.abawi@outlook.com', - packages = setuptools.find_packages(), - extras_require ={'docs': ['sphinx', 'sphinx_rtd_theme', 'myst_parser'], - 'pyzmq': ['pyzmq>=19.0.0'], - 'numpy': ['numpy>=1.19.2'], - 'headless': ['wrapyfi[pyzmq]', 'wrapyfi[numpy]'] + check_cv2("opencv-python-headless"), - 'all': ['wrapyfi[pyzmq]', 'wrapyfi[numpy]'] + check_cv2("opencv-contrib-python")}, - install_requires = ['pyyaml>=5.1.1'], - python_requires = '>=3.6', - setup_requires = ['cython>=0.29.1'] + install_requires=["pyyaml>=5.1.1"], + python_requires=">=3.6", + setup_requires=["cython>=0.29.1"], + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Intended Audience :: Robotics", + "Topic :: Middleware Wrapper", + "Topic :: Middlware :: ZeroMQ", + "Topic :: Middlware :: YARP", + "Topic :: Middlware :: ROS", + "Topic :: Middlware :: ROS 2", + "Topic :: Scientific/Engineering :: Deep Learning", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Robotics", + "Topic :: Scientific/Engineering :: Image Processing", + "License :: OSI Approved :: MIT License", + "Operating System :: Linux", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + ], ) diff --git a/upgrade_project.sh b/upgrade_project.sh index 181bb77..e0d679c 100755 --- a/upgrade_project.sh +++ b/upgrade_project.sh @@ -39,6 +39,10 @@ update_version_in_package_xml() { # GENERATE DOCUMENTATION ####################################################################################################################### +# refactor code with black +python3 -m pip install black +black . + # compile docs with sphinx cd docs python3 -m pip install -r requirements.txt diff --git a/wrapyfi/__init__.py b/wrapyfi/__init__.py index 003cc19..d04edee 100755 --- a/wrapyfi/__init__.py +++ b/wrapyfi/__init__.py @@ -7,7 +7,7 @@ def get_project_info_from_setup(): try: curr_dir = os.path.dirname(__file__) setup_path = os.path.join(curr_dir, "..", "setup.py") - with open(setup_path, 'r') as f: + with open(setup_path, "r") as f: content = f.read() except FileNotFoundError: return {} @@ -15,54 +15,72 @@ def get_project_info_from_setup(): version_match = re.search(r"version\s*=\s*['\"]([^'\"]*)['\"]", content) url_match = re.search(r"url\s*=\s*['\"]([^'\"]*)['\"]", content) doc_match = re.search(r"'Documentation':\s*['\"]([^'\"]*)['\"]", content) - + if not name_match or not version_match or not url_match: # raise RuntimeError("Unable to find name, version, or url string.") return {} - + return { - 'name': name_match.group(1), - 'version': version_match.group(1), - 'url': url_match.group(1), - 'doc': None if not doc_match else doc_match.group(1) + "name": name_match.group(1), + "version": version_match.group(1), + "url": url_match.group(1), + "doc": None if not doc_match else doc_match.group(1), } # extract project info project_info = get_project_info_from_setup() -__version__ = project_info.get('version', None) -__url__ = project_info.get('url', None) -__doc__ = project_info.get('doc', None) -name = project_info.get('name', 'wrapyfi') +__version__ = project_info.get("version", None) +__url__ = project_info.get("url", None) +__doc__ = project_info.get("doc", None) +name = project_info.get("name", "wrapyfi") if __version__ is None or __url__ is None or __doc__ is None: try: from importlib import metadata + mdata = metadata.metadata(__name__) __version__ = metadata.version(__name__) __url__ = mdata["Home-page"] # when installed with PyPi if __url__ is None: for url_extract in mdata.get_all("Project-URL"): - __url__ = url_extract.split(", ")[1] if url_extract.split(", ")[0] == "Homepage" else __url__ + __url__ = ( + url_extract.split(", ")[1] + if url_extract.split(", ")[0] == "Homepage" + else __url__ + ) if __doc__ is None: for url_extract in mdata.get_all("Project-URL"): - __doc__ = url_extract.split(", ")[1] if url_extract.split(", ")[0] == "Documentation" else __doc__ + __doc__ = ( + url_extract.split(", ")[1] + if url_extract.split(", ")[0] == "Documentation" + else __doc__ + ) except ImportError: try: # when Python < 3.8 and setuptools/pip have not been updated import pkg_resources + mdata = pkg_resources.get_distribution(__name__).metadata __version__ = pkg_resources.require(__name__)[0].version __url__ = mdata["Home-page"] # when installed with PyPi if __url__ is None: for url_extract in mdata.get_all("Project-URL"): - __url__ = url_extract.split(", ")[1] if url_extract.split(", ")[0] == "Homepage" else __url__ + __url__ = ( + url_extract.split(", ")[1] + if url_extract.split(", ")[0] == "Homepage" + else __url__ + ) if __doc__ is None: for url_extract in mdata.get_all("Project-URL"): - __doc__ = url_extract.split(", ")[1] if url_extract.split(", ")[0] == "Documentation" else __doc__ + __doc__ = ( + url_extract.split(", ")[1] + if url_extract.split(", ")[0] == "Documentation" + else __doc__ + ) except pkg_resources.DistributionNotFound: __version__ = "unknown_version" __url__ = "unknown_url" @@ -77,4 +95,5 @@ def get_project_info_from_setup(): PluginRegistrar.scan() import logging + logging.getLogger().setLevel(logging.INFO) diff --git a/wrapyfi/clients/__init__.py b/wrapyfi/clients/__init__.py index b4c87df..13a768e 100755 --- a/wrapyfi/clients/__init__.py +++ b/wrapyfi/clients/__init__.py @@ -6,9 +6,18 @@ @Clients.register("MMO", "fallback") class FallbackClient(Client): - def __init__(self, name: str, in_topic: str, carrier: str = "", missing_middleware_object: str = "", **kwargs): - logging.warning(f"Fallback client employed due to missing middleware or object type: " - f"{missing_middleware_object}") + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "", + missing_middleware_object: str = "", + **kwargs, + ): + logging.warning( + f"Fallback client employed due to missing middleware or object type: " + f"{missing_middleware_object}" + ) Client.__init__(self, name, in_topic, carrier=carrier, **kwargs) self.missing_middleware_object = missing_middleware_object @@ -25,4 +34,4 @@ def _await_reply(self): return None def close(self): - return None \ No newline at end of file + return None diff --git a/wrapyfi/clients/ros.py b/wrapyfi/clients/ros.py index ec02b12..3117597 100755 --- a/wrapyfi/clients/ros.py +++ b/wrapyfi/clients/ros.py @@ -9,13 +9,23 @@ import std_msgs.msg from wrapyfi.connect.clients import Client, Clients -from wrapyfi.middlewares.ros import ROSMiddleware, ROSNativeObjectService, ROSImageService +from wrapyfi.middlewares.ros import ( + ROSMiddleware, + ROSNativeObjectService, + ROSImageService, +) from wrapyfi.encoders import JsonEncoder, JsonDecodeHook class ROSClient(Client): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", - ros_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + ros_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the client. @@ -26,7 +36,9 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", :param kwargs: dict: Additional kwargs for the client """ if carrier or carrier != "tcp": - logging.warning("[ROS] ROS does not support other carriers than TCP for REQ/REP pattern. Using TCP.") + logging.warning( + "[ROS] ROS does not support other carriers than TCP for REQ/REP pattern. Using TCP." + ) carrier = "tcp" super().__init__(name, in_topic, carrier=carrier, **kwargs) ROSMiddleware.activate(**ros_kwargs or {}) @@ -45,9 +57,16 @@ def __del__(self): @Clients.register("NativeObject", "ros") class ROSNativeObjectClient(ROSClient): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", persistent: bool = True, - serializer_kwargs: Optional[dict] = None, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + persistent: bool = True, + serializer_kwargs: Optional[dict] = None, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObject client using the ROS String message assuming the data is serialized as a JSON string. Deserializes the data (including plugins) using the decoder and parses it to a Python object. @@ -76,7 +95,9 @@ def establish(self): Establish the client's connection to the ROS service. """ rospy.wait_for_service(self.in_topic) - self._client = rospy.ServiceProxy(self.in_topic, ROSNativeObjectService, persistent=self.persistent) + self._client = rospy.ServiceProxy( + self.in_topic, ROSNativeObjectService, persistent=self.persistent + ) if self.persistent: self.established = True @@ -103,13 +124,19 @@ def _request(self, *args, **kwargs): :param args: tuple: Positional arguments to send in the request. :param kwargs: dict: Keyword arguments to send in the request. """ - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) args_msg = std_msgs.msg.String() args_msg.data = args_str msg = self._client(args_msg) - obj = json.loads(msg.data, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + obj = json.loads( + msg.data, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs + ) self._queue.put(obj, block=False) def _await_reply(self) -> Any: @@ -122,16 +149,29 @@ def _await_reply(self) -> Any: reply = self._queue.get(block=True) return reply except queue.Full: - logging.warning(f"[ROS] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__class__.__name__}") + logging.warning( + f"[ROS] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__class__.__name__}" + ) return None @Clients.register("Image", "ros") class ROSImageClient(ROSClient): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", width: int = -1, height: int = -1, persistent: bool = True, - rgb: bool = True, fp: bool = False, serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + width: int = -1, + height: int = -1, + persistent: bool = True, + rgb: bool = True, + fp: bool = False, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The Image client using the ROS Image message parsed to a numpy array. @@ -146,7 +186,9 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", width: int = :param serializer_kwargs: dict: Additional kwargs for the serializer """ if "jpg" in kwargs: - logging.warning("[ROS] ROS currently does not support JPG encoding in REQ/REP. Using raw image.") + logging.warning( + "[ROS] ROS currently does not support JPG encoding in REQ/REP. Using raw image." + ) kwargs.pop("jpg") super().__init__(name, in_topic, carrier=carrier, **kwargs) self.width = width @@ -155,10 +197,10 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", width: int = self.fp = fp if self.fp: - self._encoding = '32FC3' if self.rgb else '32FC1' + self._encoding = "32FC3" if self.rgb else "32FC1" self._type = np.float32 else: - self._encoding = 'bgr8' if self.rgb else 'mono8' + self._encoding = "bgr8" if self.rgb else "mono8" self._type = np.uint8 self._pixel_bytes = (3 if self.rgb else 1) * np.dtype(self._type).itemsize @@ -176,7 +218,9 @@ def establish(self): Establish the client's connection to the ROS service. """ rospy.wait_for_service(self.in_topic) - self._client = rospy.ServiceProxy(self.in_topic, ROSImageService, persistent=self.persistent) + self._client = rospy.ServiceProxy( + self.in_topic, ROSImageService, persistent=self.persistent + ) if self.persistent: self.established = True @@ -203,12 +247,19 @@ def _request(self, *args, **kwargs): :param args: tuple: Positional arguments to send in the request :param kwargs: dict: Keyword arguments to send in the request """ - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) args_msg = std_msgs.msg.String() args_msg.data = args_str msg = self._client(args_msg) - self._queue.put((msg.height, msg.width, msg.encoding, msg.is_bigendian, msg.data), block=False) + self._queue.put( + (msg.height, msg.width, msg.encoding, msg.is_bigendian, msg.data), + block=False, + ) def _await_reply(self): """ @@ -220,23 +271,42 @@ def _await_reply(self): height, width, encoding, is_bigendian, data = self._queue.get(block=True) if encoding != self._encoding: raise ValueError("Incorrect encoding for listener") - if 0 < self.width != width or 0 < self.height != height or len(data) != height * width * self._pixel_bytes: + if ( + 0 < self.width != width + or 0 < self.height != height + or len(data) != height * width * self._pixel_bytes + ): raise ValueError("Incorrect image shape for listener") - img = np.frombuffer(data, dtype=np.dtype(self._type).newbyteorder('>' if is_bigendian else '<')).reshape((height, width, -1)) + img = np.frombuffer( + data, + dtype=np.dtype(self._type).newbyteorder(">" if is_bigendian else "<"), + ).reshape((height, width, -1)) if img.shape[2] == 1: img = img.squeeze(axis=2) return img except queue.Full: - logging.warning(f"[ROS] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) return None @Clients.register("AudioChunk", "ros") class ROSAudioChunkClient(ROSClient): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", persistent: bool = True, - channels: int = 1, rate: int = 44100, chunk: int = -1, serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + persistent: bool = True, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The AudioChunk client using the ROS Audio message parsed to a numpy array. @@ -271,12 +341,18 @@ def establish(self): from wrapyfi_ros_interfaces.srv import ROSAudioService except ImportError: import wrapyfi - logging.error("[ROS] Could not import ROSAudioService. " - "Make sure the ROS services in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros_interfaces_lnk.html") + + logging.error( + "[ROS] Could not import ROSAudioService. " + "Make sure the ROS services in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros_interfaces_lnk.html" + ) sys.exit(1) - self._client = rospy.ServiceProxy(self.in_topic, ROSAudioService, persistent=self.persistent) + self._client = rospy.ServiceProxy( + self.in_topic, ROSAudioService, persistent=self.persistent + ) self._req_msg = ROSAudioService._request_class() if self.persistent: self.established = True @@ -304,12 +380,26 @@ def _request(self, *args, **kwargs): :param args: tuple: Positional arguments to send in the request :param kwargs: dict: Keyword arguments to send in the request """ - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) args_msg = self._req_msg args_msg.request = args_str msg = self._client(args_msg).response - self._queue.put((msg.chunk_size, msg.channels, msg.sample_rate, msg.encoding, msg.is_bigendian, msg.data), block=False) + self._queue.put( + ( + msg.chunk_size, + msg.channels, + msg.sample_rate, + msg.encoding, + msg.is_bigendian, + msg.data, + ), + block=False, + ) def _await_reply(self): """ @@ -318,19 +408,28 @@ def _await_reply(self): :return: Tuple[np.array, int]: The received audio chunk and rate from the ROS service """ try: - chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(block=True) + chunk, channels, rate, encoding, is_bigendian, data = self._queue.get( + block=True + ) if 0 < self.rate != rate: raise ValueError("Incorrect audio rate for client") - if encoding not in ['S16LE', 'S16BE']: + if encoding not in ["S16LE", "S16BE"]: raise ValueError("Incorrect encoding for client") - if 0 < self.chunk != chunk or self.channels != channels or len(data) != chunk * channels * 4: + if ( + 0 < self.chunk != chunk + or self.channels != channels + or len(data) != chunk * channels * 4 + ): raise ValueError("Incorrect audio shape for client") - aud = np.frombuffer(data, dtype=np.dtype(np.float32).newbyteorder('>' if is_bigendian else '<')).reshape( - (chunk, channels)) + aud = np.frombuffer( + data, + dtype=np.dtype(np.float32).newbyteorder(">" if is_bigendian else "<"), + ).reshape((chunk, channels)) # aud = aud / 32767.0 return aud, rate except queue.Full: - logging.warning(f"[ROS] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) return None, self.rate - diff --git a/wrapyfi/clients/ros2.py b/wrapyfi/clients/ros2.py index 291d84e..4190adb 100755 --- a/wrapyfi/clients/ros2.py +++ b/wrapyfi/clients/ros2.py @@ -20,7 +20,9 @@ class ROS2Client(Client, Node): - def __init__(self, name: str, in_topic: str, ros2_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, name: str, in_topic: str, ros2_kwargs: Optional[dict] = None, **kwargs + ): """ Initialize the client. @@ -32,7 +34,8 @@ def __init__(self, name: str, in_topic: str, ros2_kwargs: Optional[dict] = None, carrier = "tcp" if "carrier" in kwargs and kwargs["carrier"] not in ["", None]: logging.warning( - "[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP.") + "[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP." + ) if "carrier" in kwargs: del kwargs["carrier"] ROS2Middleware.activate(**ros2_kwargs or {}) @@ -53,9 +56,14 @@ def __del__(self): @Clients.register("NativeObject", "ros2") class ROS2NativeObjectClient(ROS2Client): - def __init__(self, name: str, in_topic: str, - serializer_kwargs: Optional[dict] = None, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + serializer_kwargs: Optional[dict] = None, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObject listener using the ROS 2 String message assuming the data is serialized as a JSON string. Deserializes the data (including plugins) using the decoder and parses it to a Python object. @@ -83,16 +91,20 @@ def establish(self): from wrapyfi_ros2_interfaces.srv import ROS2NativeObjectService except ImportError: import wrapyfi - logging.error("[ROS 2] Could not import ROS2NativeObjectService. " - "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros2_interfaces_lnk.html") + + logging.error( + "[ROS 2] Could not import ROS2NativeObjectService. " + "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros2_interfaces_lnk.html" + ) sys.exit(1) self._client = self.create_client(ROS2NativeObjectService, self.in_topic) self._req_msg = ROS2NativeObjectService.Request() while not self._client.wait_for_service(timeout_sec=1.0): - logging.info('[ROS 2] Service not available, waiting again...') + logging.info("[ROS 2] Service not available, waiting again...") self.established = True def request(self, *args, **kwargs): @@ -119,8 +131,12 @@ def _request(self, *args, **kwargs): :param kwargs: dict: Keyword arguments to send in the request """ # transmit args to server - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) self._req_msg.request = args_str future = self._client.call_async(self._req_msg) # receive message from server @@ -129,7 +145,11 @@ def _request(self, *args, **kwargs): if future.done(): try: msg = future.result() - obj = json.loads(msg.response, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + obj = json.loads( + msg.response, + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) self._queue.put(obj, block=False) except Exception as e: logging.error("[ROS 2] Service call failed: %s" % e) @@ -145,15 +165,27 @@ def _await_reply(self) -> Any: reply = self._queue.get(block=True) return reply except queue.Full: - logging.warning(f"[ROS 2] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS 2] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) return None @Clients.register("Image", "ros2") class ROS2ImageClient(ROS2Client): - def __init__(self, name: str, in_topic: str, width: int = -1, height: int = -1, - rgb: bool = True, fp: bool = False, jpg: bool = False, serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The Image client using the ROS 2 Image message parsed to a numpy array. @@ -177,10 +209,10 @@ def __init__(self, name: str, in_topic: str, width: int = -1, height: int = -1, self.jpg = jpg if self.fp: - self._encoding = '32FC3' if self.rgb else '32FC1' + self._encoding = "32FC3" if self.rgb else "32FC1" self._type = np.float32 else: - self._encoding = 'bgr8' if self.rgb else 'mono8' + self._encoding = "bgr8" if self.rgb else "mono8" self._type = np.uint8 self._pixel_bytes = (3 if self.rgb else 1) * np.dtype(self._type).itemsize @@ -196,13 +228,20 @@ def establish(self): Establish the client's connection to the ROS 2 service. """ try: - from wrapyfi_ros2_interfaces.srv import ROS2ImageService, ROS2CompressedImageService + from wrapyfi_ros2_interfaces.srv import ( + ROS2ImageService, + ROS2CompressedImageService, + ) except ImportError: import wrapyfi - logging.error("[ROS 2] Could not import ROS2ImageService. " - "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros2_interfaces_lnk.html") + + logging.error( + "[ROS 2] Could not import ROS2ImageService. " + "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros2_interfaces_lnk.html" + ) sys.exit(1) if self.jpg: self._client = self.create_client(ROS2CompressedImageService, self.in_topic) @@ -212,7 +251,7 @@ def establish(self): self._req_msg = ROS2ImageService.Request() while not self._client.wait_for_service(timeout_sec=1.0): - logging.info('[ROS 2] Service not available, waiting again...') + logging.info("[ROS 2] Service not available, waiting again...") self.established = True def request(self, *args, **kwargs): @@ -239,8 +278,12 @@ def _request(self, *args, **kwargs): :param kwargs: dict: Keyword arguments to send in the request """ # transmit args to server - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) self._req_msg.request = args_str future = self._client.call_async(self._req_msg) # receive message from server @@ -253,8 +296,16 @@ def _request(self, *args, **kwargs): if self.jpg: self._queue.put((data.format, data.data), block=False) else: - self._queue.put((data.height, data.width, data.encoding, data.is_bigendian, data.data), - block=False) + self._queue.put( + ( + data.height, + data.width, + data.encoding, + data.is_bigendian, + data.data, + ), + block=False, + ) except Exception as e: logging.error("[ROS 2] Service call failed: %s" % e) break @@ -268,33 +319,55 @@ def _await_reply(self): try: if self.jpg: format, data = self._queue.get(block=True) - if format != 'jpeg': + if format != "jpeg": raise ValueError(f"Unsupported image format: {format}") if self.rgb: img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR) else: - img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE) + img = cv2.imdecode( + np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE + ) else: - height, width, encoding, is_bigendian, data = self._queue.get(block=True) + height, width, encoding, is_bigendian, data = self._queue.get( + block=True + ) if encoding != self._encoding: raise ValueError("Incorrect encoding for listener") - if 0 < self.width != width or 0 < self.height != height or len(data) != height * width * self._pixel_bytes: + if ( + 0 < self.width != width + or 0 < self.height != height + or len(data) != height * width * self._pixel_bytes + ): raise ValueError("Incorrect image shape for listener") - img = np.frombuffer(data, dtype=np.dtype(self._type).newbyteorder('>' if is_bigendian else '<')).reshape((height, width, -1)) + img = np.frombuffer( + data, + dtype=np.dtype(self._type).newbyteorder( + ">" if is_bigendian else "<" + ), + ).reshape((height, width, -1)) if img.shape[2] == 1: img = img.squeeze(axis=2) return img except queue.Full: - logging.warning(f"[ROS 2] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS 2] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) return None @Clients.register("AudioChunk", "ros2") class ROS2AudioChunkClient(ROS2Client): - def __init__(self, name: str, in_topic: str, - channels: int = 1, rate: int = 44100, chunk: int = -1, - serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The AudioChunk client using the ROS 2 Audio message parsed to a numpy array. @@ -328,16 +401,20 @@ def establish(self): from wrapyfi_ros2_interfaces.srv import ROS2AudioService except ImportError: import wrapyfi - logging.error("[ROS 2] Could not import ROS2AudioService. " - "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros2_interfaces_lnk.html") + + logging.error( + "[ROS 2] Could not import ROS2AudioService. " + "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros2_interfaces_lnk.html" + ) sys.exit(1) self._client = self.create_client(ROS2AudioService, self.in_topic) self._req_msg = ROS2AudioService.Request() while not self._client.wait_for_service(timeout_sec=1.0): - logging.info('[ROS 2] Service not available, waiting again...') + logging.info("[ROS 2] Service not available, waiting again...") self.established = True def request(self, *args, **kwargs): @@ -363,8 +440,12 @@ def _request(self, *args, **kwargs): :param args: tuple: Positional arguments to send in the request :param kwargs: dict: Keyword arguments to send in the request """ - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) self._req_msg.request = args_str future = self._client.call_async(self._req_msg) @@ -374,8 +455,17 @@ def _request(self, *args, **kwargs): try: msg = future.result() data = msg.response - self._queue.put((data.chunk_size, data.channels, data.sample_rate, data.encoding, data.is_bigendian, data.data), - block=False) + self._queue.put( + ( + data.chunk_size, + data.channels, + data.sample_rate, + data.encoding, + data.is_bigendian, + data.data, + ), + block=False, + ) except Exception as e: logging.error("[ROS 2] Service call failed: %s" % e) break @@ -387,19 +477,28 @@ def _await_reply(self): :return: Tuple[np.ndarray, int]: The received message as a numpy array formatted as (np.ndarray[audio_chunk, channels], int[samplerate]) """ try: - chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(block=False) + chunk, channels, rate, encoding, is_bigendian, data = self._queue.get( + block=False + ) if 0 < self.rate != rate: raise ValueError("Incorrect audio rate for publisher") - if encoding not in ['S16LE', 'S16BE']: + if encoding not in ["S16LE", "S16BE"]: raise ValueError("Incorrect encoding for listener") - if 0 < self.chunk != chunk or self.channels != channels or len(data) != chunk * channels * 4: + if ( + 0 < self.chunk != chunk + or self.channels != channels + or len(data) != chunk * channels * 4 + ): raise ValueError("Incorrect audio shape for listener") - aud = np.frombuffer(data, dtype=np.dtype(np.float32).newbyteorder('>' if is_bigendian else '<')).reshape( - (chunk, channels)) + aud = np.frombuffer( + data, + dtype=np.dtype(np.float32).newbyteorder(">" if is_bigendian else "<"), + ).reshape((chunk, channels)) # aud = aud / 32767.0 return aud, rate except queue.Full: - logging.warning(f"[ROS 2] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS 2] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) return None - diff --git a/wrapyfi/clients/yarp.py b/wrapyfi/clients/yarp.py index 7b53fee..4ffe3a2 100755 --- a/wrapyfi/clients/yarp.py +++ b/wrapyfi/clients/yarp.py @@ -14,8 +14,15 @@ class YarpClient(Client): - def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", - persistent: bool = True, yarp_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + persistent: bool = True, + yarp_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the client. @@ -30,6 +37,7 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca YarpMiddleware.activate(**yarp_kwargs or {}) self.persistent = persistent + def close(self): """ Close the client. @@ -44,9 +52,16 @@ def __del__(self): @Clients.register("NativeObject", "yarp") class YarpNativeObjectClient(YarpClient): - def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", - persistent: bool = True, - serializer_kwargs: Optional[dict] = None, deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + persistent: bool = True, + serializer_kwargs: Optional[dict] = None, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObject listener using the YARP Bottle construct assuming the data is serialized as a JSON string. Deserializes the data (including plugins) using the decoder and parses it to a Python object. @@ -61,7 +76,9 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca :param deserializer_kwargs: dict: Additional kwargs for the deserializer :param kwargs: dict: Additional kwargs for the client """ - super().__init__(name, in_topic, carrier=carrier, persistent=persistent, **kwargs) + super().__init__( + name, in_topic, carrier=carrier, persistent=persistent, **kwargs + ) self._port = None self._queue = queue.Queue(maxsize=1) @@ -108,8 +125,12 @@ def _request(self, *args, **kwargs): :param args: tuple: Positional arguments to send in the request :param kwargs: dict: Keyword arguments to send in the request """ - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) args_msg = yarp.Bottle() args_msg.clear() args_msg.addString(args_str) @@ -118,7 +139,11 @@ def _request(self, *args, **kwargs): msg.clear() self._port.write(args_msg, msg) - obj = json.loads(msg.get(0).asString(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + obj = json.loads( + msg.get(0).asString(), + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) self._queue.put(obj, block=False) def _await_reply(self): @@ -131,16 +156,28 @@ def _await_reply(self): reply = self._queue.get(block=True) return reply except queue.Full: - logging.warning(f"[YARP] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[YARP] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) return None @Clients.register("Image", "yarp") class YarpImageClient(YarpNativeObjectClient): - def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, - persistent: bool = True, serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + persistent: bool = True, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The Image client using the YARP Bottle construct parsed to a numpy array. @@ -155,9 +192,18 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca :param serializer_kwargs: dict: Additional kwargs for the serializer """ if "jpg" in kwargs: - logging.warning("[YARP] YARP currently does not support JPG encoding in REQ/REP. Using raw image.") + logging.warning( + "[YARP] YARP currently does not support JPG encoding in REQ/REP. Using raw image." + ) kwargs.pop("jpg") - super().__init__(name, in_topic, carrier=carrier, persistent=persistent, serializer_kwargs=serializer_kwargs, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + persistent=persistent, + serializer_kwargs=serializer_kwargs, + **kwargs, + ) self.width = width self.height = height self.rgb = rgb @@ -170,15 +216,23 @@ def _request(self, *args, **kwargs): :param args: tuple: Positional arguments to send in the request :param kwargs: dict: Keyword arguments to send in the request """ - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) args_msg = yarp.Bottle() args_msg.clear() args_msg.addString(args_str) msg = yarp.Bottle() msg.clear() self._port.write(args_msg, msg) - img = json.loads(msg.get(0).asString(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + img = json.loads( + msg.get(0).asString(), + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) height, width, channels = img.shape if 0 < self.width != width or 0 < self.height != height: raise ValueError("Incorrect image shape for client") @@ -188,9 +242,18 @@ def _request(self, *args, **kwargs): @Clients.register("AudioChunk", "yarp") class YarpAudioChunkClient(YarpNativeObjectClient): - def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", - channels: int = 1, rate: int = 44100, chunk: int = -1, - persistent: bool = True, serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + persistent: bool = True, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The AudioChunk client using the YARP Bottle construct parsed to a numpy array. @@ -203,7 +266,14 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca :param persistent: bool: Whether to keep the service connection alive across multiple service calls. Default is True :param serializer_kwargs: dict: Additional kwargs for the serializer """ - super().__init__(name, in_topic, carrier=carrier, persistent=persistent, serializer_kwargs=serializer_kwargs, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + persistent=persistent, + serializer_kwargs=serializer_kwargs, + **kwargs, + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -215,18 +285,30 @@ def _request(self, *args, **kwargs): :param args: tuple: Positional arguments to send in the request :param kwargs: dict: Keyword arguments to send in the request """ - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) args_msg = yarp.Bottle() args_msg.clear() args_msg.addString(args_str) msg = yarp.Bottle() msg.clear() self._port.write(args_msg, msg) - chunk, channels, rate, aud = json.loads(msg.get(0).asString(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + chunk, channels, rate, aud = json.loads( + msg.get(0).asString(), + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) if 0 < self.rate != rate: raise ValueError("Incorrect audio rate for client") - if 0 < self.chunk != chunk or self.channels != channels or aud.size != chunk * channels: + if ( + 0 < self.chunk != chunk + or self.channels != channels + or aud.size != chunk * channels + ): raise ValueError("Incorrect audio shape for client") else: - self._queue.put((aud, rate), block=False) \ No newline at end of file + self._queue.put((aud, rate), block=False) diff --git a/wrapyfi/clients/zeromq.py b/wrapyfi/clients/zeromq.py index 8c6f692..5d93839 100644 --- a/wrapyfi/clients/zeromq.py +++ b/wrapyfi/clients/zeromq.py @@ -18,9 +18,16 @@ class ZeroMQClient(Client): - def __init__(self, name, in_topic, carrier="tcp", - socket_ip: str = SOCKET_IP, socket_rep_port: int = SOCKET_REP_PORT, - zeromq_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name, + in_topic, + carrier="tcp", + socket_ip: str = SOCKET_IP, + socket_rep_port: int = SOCKET_REP_PORT, + zeromq_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the client. @@ -31,10 +38,14 @@ def __init__(self, name, in_topic, carrier="tcp", :param kwargs: dict: Additional kwargs for the client """ if in_topic != "": - logging.warning(f"[ZeroMQ] ZeroMQ does not support topics for the REQ/REP pattern. Topic {in_topic} removed") + logging.warning( + f"[ZeroMQ] ZeroMQ does not support topics for the REQ/REP pattern. Topic {in_topic} removed" + ) in_topic = "" if carrier or carrier != "tcp": - logging.warning("[ZeroMQ] ZeroMQ does not support other carriers than TCP for REQ/REP pattern. Using TCP.") + logging.warning( + "[ZeroMQ] ZeroMQ does not support other carriers than TCP for REQ/REP pattern. Using TCP." + ) carrier = "tcp" super().__init__(name, in_topic, carrier=carrier, **kwargs) @@ -56,9 +67,15 @@ def __del__(self): @Clients.register("NativeObject", "zeromq") class ZeroMQNativeObjectClient(ZeroMQClient): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", - serializer_kwargs: Optional[dict] = None, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + serializer_kwargs: Optional[dict] = None, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific client for handling native Python objects, serializing them to JSON strings for transmission. @@ -86,9 +103,13 @@ def establish(self, **kwargs): self._socket = zmq.Context().instance().socket(zmq.REQ) for socket_property in ZeroMQMiddlewareReqRep().zeromq_kwargs.items(): if isinstance(socket_property[1], str): - self._socket.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1]) + self._socket.setsockopt_string( + getattr(zmq, socket_property[0]), socket_property[1] + ) else: - self._socket.setsockopt(getattr(zmq, socket_property[0]), socket_property[1]) + self._socket.setsockopt( + getattr(zmq, socket_property[0]), socket_property[1] + ) self._socket.connect(self.socket_address) self.established = True @@ -115,11 +136,15 @@ def _request(self, *args, **kwargs): :param args: tuple: Arguments to be serialized and sent :param kwargs: dict: Keyword arguments to be serialized and sent """ - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs + ) self._socket.send_string(args_str) obj_str = self._socket.recv_string() - obj = json.loads(obj_str, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + obj = json.loads( + obj_str, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs + ) self._queue.put(obj, block=False) def _await_reply(self): @@ -132,16 +157,28 @@ def _await_reply(self): reply = self._queue.get(block=True) return reply except queue.Empty: - logging.warning(f"[ZeroMQ] Discarding data because queue is empty. " - f"This happened due to bad synchronization in {self.__class__.__name__}") + logging.warning( + f"[ZeroMQ] Discarding data because queue is empty. " + f"This happened due to bad synchronization in {self.__class__.__name__}" + ) return None @Clients.register("Image", "zeromq") class ZeroMQImageClient(ZeroMQNativeObjectClient): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, - serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The Image client using the ZeroMQ message construct parsed to a numpy array. @@ -155,7 +192,13 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", :param jpg: bool: True if the image should be compressed to JPG before sending. Default is False :param serializer_kwargs: dict: Additional kwargs for the serializer """ - super().__init__(name, in_topic, carrier=carrier, serializer_kwargs=serializer_kwargs, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + serializer_kwargs=serializer_kwargs, + **kwargs, + ) self.width = width self.height = height self.rgb = rgb @@ -171,12 +214,16 @@ def _request(self, *args, **kwargs): :param args: tuple: Arguments to be serialized and sent :param kwargs: dict: Keyword arguments to be serialized and sent """ - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs + ) self._socket.send_string(args_str) if self.jpg: reply_bytes = self._socket.recv() - reply_img = cv2.imdecode(np.frombuffer(reply_bytes, np.uint8), cv2.IMREAD_ANYCOLOR) + reply_img = cv2.imdecode( + np.frombuffer(reply_bytes, np.uint8), cv2.IMREAD_ANYCOLOR + ) else: reply_str = self._socket.recv_string() reply_img_list = json.loads(reply_str) @@ -192,20 +239,34 @@ def _await_reply(self): try: img = self._queue.get(block=True) height, width, channels = img.shape - if 0 < self.width != width or 0 < self.height != height or img.size != height * width * (3 if self.rgb else 1): + if ( + 0 < self.width != width + or 0 < self.height != height + or img.size != height * width * (3 if self.rgb else 1) + ): raise ValueError("Incorrect image shape for subscriber") return img except queue.Empty: - logging.warning(f"[ZeroMQ] Discarding data because queue is empty. " - f"This happened due to bad synchronization in {self.__class__.__name__}") + logging.warning( + f"[ZeroMQ] Discarding data because queue is empty. " + f"This happened due to bad synchronization in {self.__class__.__name__}" + ) return None @Clients.register("AudioChunk", "zeromq") class ZeroMQAudioChunkClient(ZeroMQNativeObjectClient): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", - channels: int = 1, rate: int = 44100, chunk: int = -1, - serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The AudioChunk client using the ZeroMQ message construct parsed to a numpy array. @@ -219,7 +280,13 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", :param jpg: bool: True if the image should be compressed to JPG before sending. Default is False :param serializer_kwargs: dict: Additional kwargs for the serializer """ - super().__init__(name, in_topic, carrier=carrier, serializer_kwargs=serializer_kwargs, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + serializer_kwargs=serializer_kwargs, + **kwargs, + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -231,7 +298,9 @@ def _request(self, *args, **kwargs): :param args: tuple: Arguments to be serialized and sent :param kwargs: dict: Keyword arguments to be serialized and sent """ - args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs) + args_str = json.dumps( + [args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs + ) self._socket.send_string(args_str) reply_str = self._socket.recv_string() @@ -250,10 +319,16 @@ def _await_reply(self): chunk, channels, rate, aud = self._queue.get(block=True) if 0 < self.rate != rate: raise ValueError("Incorrect audio rate for listener") - if 0 < self.chunk != chunk or self.channels != channels or aud.size != chunk * channels: + if ( + 0 < self.chunk != chunk + or self.channels != channels + or aud.size != chunk * channels + ): raise ValueError("Incorrect audio shape for listener") return aud, rate except queue.Empty: - logging.warning(f"[ZeroMQ] Discarding data because queue is empty. " - f"This happened due to bad synchronization in {self.__class__.__name__}") - return None \ No newline at end of file + logging.warning( + f"[ZeroMQ] Discarding data because queue is empty. " + f"This happened due to bad synchronization in {self.__class__.__name__}" + ) + return None diff --git a/wrapyfi/config/manager.py b/wrapyfi/config/manager.py index 36df4fa..2943fee 100755 --- a/wrapyfi/config/manager.py +++ b/wrapyfi/config/manager.py @@ -9,6 +9,7 @@ class ConfigManager(metaclass=SingletonOptimized): """ The configuration manager is a singleton which is invoked once throughout the runtime. """ + def __init__(self, config: Optional[Union[dict, str]], **kwargs): """ Initializing the ConfigManager. The configuration can be provided as a yaml file name or as a dictionary. @@ -44,4 +45,3 @@ def __writefile(self, filename: str): """ with open(filename, "w") as fp: yaml.safe_dump(self.config, fp) - diff --git a/wrapyfi/connect/__init__.py b/wrapyfi/connect/__init__.py index fe067bf..20f5431 100755 --- a/wrapyfi/connect/__init__.py +++ b/wrapyfi/connect/__init__.py @@ -6,4 +6,4 @@ Listeners.scan() Publishers.scan() Servers.scan() -Clients.scan() \ No newline at end of file +Clients.scan() diff --git a/wrapyfi/connect/clients.py b/wrapyfi/connect/clients.py index 7a9a6c4..45666f2 100755 --- a/wrapyfi/connect/clients.py +++ b/wrapyfi/connect/clients.py @@ -9,6 +9,7 @@ class Clients(object): """ A class that holds all clients and their corresponding middleware communicators. """ + registry = {} mwares = set() @@ -20,10 +21,12 @@ def register(cls, data_type: str, communicator: str): :param data_type: str: The data type to register the client for e.g., "NativeObject", "Image", "AudioChunk", etc. :param communicator: str: The middleware communicator to register the client for e.g., "ros", "ros2", "yarp", "zeromq", etc. """ + def decorator(cls_): cls.registry[data_type + ":" + communicator] = cls_ cls.mwares.add(communicator) return cls_ + return decorator @staticmethod @@ -31,9 +34,15 @@ def scan(): """ Scan for clients and add them to the registry. """ - modules = glob(os.path.join(os.path.dirname(__file__), "..", "clients", "*.py"), recursive=True) - modules = ["wrapyfi.clients." + module.replace(os.path.dirname(__file__) + "/../clients/", "") for module in - modules] + modules = glob( + os.path.join(os.path.dirname(__file__), "..", "clients", "*.py"), + recursive=True, + ) + modules = [ + "wrapyfi.clients." + + module.replace(os.path.dirname(__file__) + "/../clients/", "") + for module in modules + ] dynamic_module_import(modules, globals()) @@ -41,6 +50,7 @@ class Client(object): """ A base class for clients. """ + def __init__(self, name: str, in_topic: str, carrier: str = "", **kwargs): """ Initialize the client. diff --git a/wrapyfi/connect/listeners.py b/wrapyfi/connect/listeners.py index fe4c595..54fcdfe 100755 --- a/wrapyfi/connect/listeners.py +++ b/wrapyfi/connect/listeners.py @@ -9,6 +9,7 @@ class ListenerWatchDog(metaclass=SingletonOptimized): """ A watchdog that scans for listeners and removes them from the ring if they are not established. """ + def __init__(self, repeats: int = 10, inner_repeats: int = 10): """ Initialize the ListenerWatchDog. @@ -53,6 +54,7 @@ class Listeners(object): """ A class that holds all listeners and their corresponding middleware communicators. """ + registry = {} mwares = set() @@ -64,10 +66,12 @@ def register(cls, data_type: str, communicator: str): :param data_type: str: The data type to register the listener for e.g., "NativeObject", "Image", "AudioChunk", etc. :param communicator: str: The middleware communicator to register the listener for e.g., "ros", "ros2", "yarp", "zeromq", etc. """ + def decorator(cls_): cls.registry[data_type + ":" + communicator] = cls_ cls.mwares.add(communicator) return cls_ + return decorator @staticmethod @@ -75,9 +79,15 @@ def scan(): """ Scan for listeners and add them to the registry. """ - modules = glob(os.path.join(os.path.dirname(__file__), "..", "listeners", "*.py"), recursive=True) - modules = ["wrapyfi.listeners." + module.replace(os.path.dirname(__file__) + "/../listeners/", "") for module in - modules] + modules = glob( + os.path.join(os.path.dirname(__file__), "..", "listeners", "*.py"), + recursive=True, + ) + modules = [ + "wrapyfi.listeners." + + module.replace(os.path.dirname(__file__) + "/../listeners/", "") + for module in modules + ] dynamic_module_import(modules, globals()) @@ -85,7 +95,15 @@ class Listener(object): """ A base class for listeners. """ - def __init__(self, name: str, in_topic: str, carrier: str = "", should_wait: bool = True, **kwargs): + + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "", + should_wait: bool = True, + **kwargs, + ): """ Initialize the Listener. diff --git a/wrapyfi/connect/publishers.py b/wrapyfi/connect/publishers.py index edd9b6b..de92df1 100755 --- a/wrapyfi/connect/publishers.py +++ b/wrapyfi/connect/publishers.py @@ -9,6 +9,7 @@ class PublisherWatchDog(metaclass=SingletonOptimized): """ A watchdog that scans for publishers and removes them from the ring if they are not established. """ + def __init__(self, repeats: int = 10, inner_repeats: int = 10): """ Initialize the PublisherWatchDog. @@ -53,6 +54,7 @@ class Publishers(object): """ A class that holds all publishers and their corresponding middleware communicators. """ + registry = {} mwares = set() @@ -65,10 +67,12 @@ def register(cls, data_type: str, communicator: str): :param communicator: str: The middleware communicator to register the publisher for e.g., "ros", "ros2", "yarp", "zeromq", etc. :return: Callable[..., Any]: A decorator function that registers the decorated class as a publisher for the given data type and middleware communicator """ + def decorator(cls_): cls.registry[data_type + ":" + communicator] = cls_ cls.mwares.add(communicator) return cls_ + return decorator @staticmethod @@ -76,9 +80,15 @@ def scan(): """ Scan for publishers and add them to the registry. """ - modules = glob(os.path.join(os.path.dirname(__file__), "..", "publishers", "*.py"), recursive=True) - modules = ["wrapyfi.publishers." + module.replace(os.path.dirname(__file__) + "/../publishers/", "") for module in - modules] + modules = glob( + os.path.join(os.path.dirname(__file__), "..", "publishers", "*.py"), + recursive=True, + ) + modules = [ + "wrapyfi.publishers." + + module.replace(os.path.dirname(__file__) + "/../publishers/", "") + for module in modules + ] dynamic_module_import(modules, globals()) @@ -86,7 +96,15 @@ class Publisher(object): """ A base class for all publishers. """ - def __init__(self, name: str, out_topic: str, carrier: str = "", should_wait: bool = True, **kwargs): + + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "", + should_wait: bool = True, + **kwargs, + ): """ Initialize the Publisher. diff --git a/wrapyfi/connect/servers.py b/wrapyfi/connect/servers.py index 4f25171..831a924 100755 --- a/wrapyfi/connect/servers.py +++ b/wrapyfi/connect/servers.py @@ -10,6 +10,7 @@ class Servers(object): """ A class that holds all servers and their corresponding middleware communicators. """ + registry = {} mwares = set() @@ -22,10 +23,12 @@ def register(cls, data_type: str, communicator: str): :param communicator: str: The middleware communicator to register the server for e.g., "ros", "ros2", "yarp", "zeromq", etc. :return: Callable: A decorator that registers the server with the given data type and middleware communicator """ + def decorator(cls_): cls.registry[data_type + ":" + communicator] = cls_ cls.mwares.add(communicator) return cls_ + return decorator @staticmethod @@ -33,9 +36,15 @@ def scan(): """ Scan for servers and add them to the registry. """ - modules = glob(os.path.join(os.path.dirname(__file__), "..", "servers", "*.py"), recursive=True) - modules = ["wrapyfi.servers." + module.replace(os.path.dirname(__file__) + "/../servers/", "") for module in - modules] + modules = glob( + os.path.join(os.path.dirname(__file__), "..", "servers", "*.py"), + recursive=True, + ) + modules = [ + "wrapyfi.servers." + + module.replace(os.path.dirname(__file__) + "/../servers/", "") + for module in modules + ] dynamic_module_import(modules, globals()) @@ -43,7 +52,15 @@ class Server(object): """ A base class for servers. """ - def __init__(self, name: str, out_topic: str, carrier: str = "", out_topic_connect: Optional[str] = None, **kwargs): + + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "", + out_topic_connect: Optional[str] = None, + **kwargs, + ): """ Initialize the server. @@ -55,7 +72,9 @@ def __init__(self, name: str, out_topic: str, carrier: str = "", out_topic_conne self.__name__ = name self.out_topic = out_topic self.carrier = carrier - self.out_topic_connect = out_topic + ":out" if out_topic_connect is None else out_topic_connect + self.out_topic_connect = ( + out_topic + ":out" if out_topic_connect is None else out_topic_connect + ) self.established = False def establish(self): diff --git a/wrapyfi/connect/wrapper.py b/wrapyfi/connect/wrapper.py index 0122bdd..6403ef2 100755 --- a/wrapyfi/connect/wrapper.py +++ b/wrapyfi/connect/wrapper.py @@ -29,7 +29,9 @@ def __init__(self): self.activate_communication(getattr(self.__class__, key), mode=value) @classmethod - def __trigger_publish(cls, func: Callable[..., Any], instance_id: str, kwd: dict, *wds, **kwds): + def __trigger_publish( + cls, func: Callable[..., Any], instance_id: str, kwd: dict, *wds, **kwds + ): """ Triggers the publish mode of the middleware communicator. @@ -42,60 +44,97 @@ def __trigger_publish(cls, func: Callable[..., Any], instance_id: str, kwd: dict :raises: KeyError: If the intended publisher type and middleware are unavailable, resorting to a fallback publisher """ - if "wrapped_executor" not in \ - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][0]: + if ( + "wrapped_executor" + not in cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"][0] + ): # instantiate the publishers - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"].reverse() - for communicator in cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][ - "communicator"]: + cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][ + "communicator" + ].reverse() + for communicator in cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"]: # single element if isinstance(communicator["return_func_type"], str): - return_func_pub_kwargs = deepcopy(communicator["return_func_kwargs"]) - return_func_pub_kwargs.update(return_func_pub_kwargs.get("publisher_kwargs", {})) + return_func_pub_kwargs = deepcopy( + communicator["return_func_kwargs"] + ) + return_func_pub_kwargs.update( + return_func_pub_kwargs.get("publisher_kwargs", {}) + ) return_func_pub_kwargs.pop("listener_kwargs", None) return_func_pub_kwargs.pop("publisher_kwargs", None) new_args, new_kwargs = match_args( - communicator["return_func_args"], return_func_pub_kwargs, wds[1:], kwd) + communicator["return_func_args"], + return_func_pub_kwargs, + wds[1:], + kwd, + ) return_func_type = communicator["return_func_type"] - return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR) + return_func_middleware = new_kwargs.pop( + "middleware", DEFAULT_COMMUNICATOR + ) try: communicator["wrapped_executor"] = pub.Publishers.registry[ - return_func_type + return_func_middleware](*new_args, **new_kwargs) + return_func_type + return_func_middleware + ](*new_args, **new_kwargs) except KeyError: communicator["wrapped_executor"] = pub.Publishers.registry[ - "MMO:fallback"](*new_args, - missing_middleware_object=return_func_type + return_func_middleware, - **new_kwargs) + "MMO:fallback" + ]( + *new_args, + missing_middleware_object=return_func_type + + return_func_middleware, + **new_kwargs, + ) communicator["return_func_type"] = "MMO:" # list for single return elif isinstance(communicator["return_func_type"], list): communicator["wrapped_executor"] = [] for comm_idx in range(len(communicator["return_func_type"])): - return_func_pub_kwargs = deepcopy(communicator["return_func_kwargs"][comm_idx]) - return_func_pub_kwargs.update(return_func_pub_kwargs.get("publisher_kwargs", {})) + return_func_pub_kwargs = deepcopy( + communicator["return_func_kwargs"][comm_idx] + ) + return_func_pub_kwargs.update( + return_func_pub_kwargs.get("publisher_kwargs", {}) + ) return_func_pub_kwargs.pop("listener_kwargs", None) return_func_pub_kwargs.pop("publisher_kwargs", None) new_args, new_kwargs = match_args( - communicator["return_func_args"][comm_idx], return_func_pub_kwargs, wds[1:], kwd) + communicator["return_func_args"][comm_idx], + return_func_pub_kwargs, + wds[1:], + kwd, + ) return_func_type = communicator["return_func_type"][comm_idx] - return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR) + return_func_middleware = new_kwargs.pop( + "middleware", DEFAULT_COMMUNICATOR + ) try: communicator["wrapped_executor"].append( - pub.Publishers.registry[return_func_type + return_func_middleware](*new_args, - **new_kwargs)) + pub.Publishers.registry[ + return_func_type + return_func_middleware + ](*new_args, **new_kwargs) + ) except KeyError: communicator["wrapped_executor"].append( - pub.Publishers.registry[ - "MMO:fallback"](*new_args, - missing_middleware_object=return_func_type + return_func_middleware, - **new_kwargs)) + pub.Publishers.registry["MMO:fallback"]( + *new_args, + missing_middleware_object=return_func_type + + return_func_middleware, + **new_kwargs, + ) + ) communicator["return_func_type"][comm_idx] = "MMO:" returns = func(*wds, **kwds) for ret_idx, ret in enumerate(returns): - wrp_exec = \ - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][ret_idx][ - "wrapped_executor"] + wrp_exec = cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"][ret_idx]["wrapped_executor"] # single element if isinstance(wrp_exec, pub.Publisher): wrp_exec.publish(ret) @@ -106,7 +145,9 @@ def __trigger_publish(cls, func: Callable[..., Any], instance_id: str, kwd: dict return returns @classmethod - def __trigger_listen(cls, func: Callable[..., Any], instance_id: str, kwd: dict, *wds, **kwds): + def __trigger_listen( + cls, func: Callable[..., Any], instance_id: str, kwd: dict, *wds, **kwds + ): """ Triggers the listen mode of the middleware communicator. @@ -119,59 +160,104 @@ def __trigger_listen(cls, func: Callable[..., Any], instance_id: str, kwd: dict, :raises: KeyError: If the intended listener type and middleware are unavailable, resorting to a fallback listener """ - if "wrapped_executor" not in \ - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][0]: + if ( + "wrapped_executor" + not in cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"][0] + ): # instantiate the listeners - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"].reverse() - for communicator in cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"]: + cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][ + "communicator" + ].reverse() + for communicator in cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"]: # single element if isinstance(communicator["return_func_type"], str): - return_func_lsn_kwargs = deepcopy(communicator["return_func_kwargs"]) - return_func_lsn_kwargs.update(return_func_lsn_kwargs.get("listener_kwargs", {})) + return_func_lsn_kwargs = deepcopy( + communicator["return_func_kwargs"] + ) + return_func_lsn_kwargs.update( + return_func_lsn_kwargs.get("listener_kwargs", {}) + ) return_func_lsn_kwargs.pop("listener_kwargs", None) return_func_lsn_kwargs.pop("publisher_kwargs", None) - new_args, new_kwargs = match_args(communicator["return_func_args"], return_func_lsn_kwargs, wds[1:], - kwd) + new_args, new_kwargs = match_args( + communicator["return_func_args"], + return_func_lsn_kwargs, + wds[1:], + kwd, + ) return_func_type = communicator["return_func_type"] - return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR) + return_func_middleware = new_kwargs.pop( + "middleware", DEFAULT_COMMUNICATOR + ) try: communicator["wrapped_executor"] = lsn.Listeners.registry[ - return_func_type + return_func_middleware](*new_args, **new_kwargs) + return_func_type + return_func_middleware + ](*new_args, **new_kwargs) except KeyError: communicator["wrapped_executor"] = lsn.Listeners.registry[ - "MMO:fallback"](*new_args, - missing_middleware_object=return_func_type + return_func_middleware, - **new_kwargs) + "MMO:fallback" + ]( + *new_args, + missing_middleware_object=return_func_type + + return_func_middleware, + **new_kwargs, + ) communicator["return_func_type"] = "MMO:" # list for single return elif isinstance(communicator["return_func_type"], list): communicator["wrapped_executor"] = [] for comm_idx in range(len(communicator["return_func_type"])): - return_func_lsn_kwargs = deepcopy(communicator["return_func_kwargs"][comm_idx]) - return_func_lsn_kwargs.update(return_func_lsn_kwargs.get("listener_kwargs", {})) + return_func_lsn_kwargs = deepcopy( + communicator["return_func_kwargs"][comm_idx] + ) + return_func_lsn_kwargs.update( + return_func_lsn_kwargs.get("listener_kwargs", {}) + ) return_func_lsn_kwargs.pop("listener_kwargs", None) return_func_lsn_kwargs.pop("publisher_kwargs", None) - new_args, new_kwargs = match_args(communicator["return_func_args"][comm_idx], - return_func_lsn_kwargs, wds[1:], kwd) + new_args, new_kwargs = match_args( + communicator["return_func_args"][comm_idx], + return_func_lsn_kwargs, + wds[1:], + kwd, + ) return_func_type = communicator["return_func_type"][comm_idx] - return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR) + return_func_middleware = new_kwargs.pop( + "middleware", DEFAULT_COMMUNICATOR + ) try: communicator["wrapped_executor"].append( - lsn.Listeners.registry[return_func_type + return_func_middleware](*new_args, **new_kwargs)) + lsn.Listeners.registry[ + return_func_type + return_func_middleware + ](*new_args, **new_kwargs) + ) except KeyError: communicator["wrapped_executor"].append( - lsn.Listeners.registry[ - "MMO:fallback"](*new_args, - missing_middleware_object=return_func_type + return_func_middleware, - **new_kwargs)) + lsn.Listeners.registry["MMO:fallback"]( + *new_args, + missing_middleware_object=return_func_type + + return_func_middleware, + **new_kwargs, + ) + ) communicator["return_func_type"][comm_idx] = "MMO:" returns = [] for ret_idx in range( - len(cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"])): - wrp_exec = cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][ret_idx][ - "wrapped_executor"] + len( + cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][ + "communicator" + ] + ) + ): + wrp_exec = cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"][ret_idx]["wrapped_executor"] # single element if isinstance(wrp_exec, lsn.Listener): returns.append(wrp_exec.listen()) @@ -184,7 +270,9 @@ def __trigger_listen(cls, func: Callable[..., Any], instance_id: str, kwd: dict, return returns @classmethod - def __trigger_reply(cls, func: Callable[..., Any], instance_id: str, kwd, *wds, **kwds): + def __trigger_reply( + cls, func: Callable[..., Any], instance_id: str, kwd, *wds, **kwds + ): """ Triggers the reply mode of the middleware communicator. @@ -197,60 +285,99 @@ def __trigger_reply(cls, func: Callable[..., Any], instance_id: str, kwd, *wds, :raises: KeyError: If the intended server type and middleware are unavailable, resorting to a fallback server """ - if "wrapped_executor" not in \ - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][0]: + if ( + "wrapped_executor" + not in cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"][0] + ): # instantiate the publishers - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"].reverse() - for communicator in cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"]: + cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][ + "communicator" + ].reverse() + for communicator in cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"]: # single element if isinstance(communicator["return_func_type"], str): - return_func_pub_kwargs = deepcopy(communicator["return_func_kwargs"]) - return_func_pub_kwargs.update(return_func_pub_kwargs.get("publisher_kwargs", {})) + return_func_pub_kwargs = deepcopy( + communicator["return_func_kwargs"] + ) + return_func_pub_kwargs.update( + return_func_pub_kwargs.get("publisher_kwargs", {}) + ) return_func_pub_kwargs.pop("listener_kwargs", None) return_func_pub_kwargs.pop("publisher_kwargs", None) new_args, new_kwargs = match_args( - communicator["return_func_args"], return_func_pub_kwargs, wds[1:], kwd) + communicator["return_func_args"], + return_func_pub_kwargs, + wds[1:], + kwd, + ) return_func_type = communicator["return_func_type"] - return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR) + return_func_middleware = new_kwargs.pop( + "middleware", DEFAULT_COMMUNICATOR + ) try: communicator["wrapped_executor"] = srv.Servers.registry[ - return_func_type + return_func_middleware](*new_args, **new_kwargs) + return_func_type + return_func_middleware + ](*new_args, **new_kwargs) except KeyError: communicator["wrapped_executor"] = srv.Servers.registry[ - "MMO:fallback"](*new_args, - missing_middleware_object=return_func_type + return_func_middleware, - **new_kwargs) + "MMO:fallback" + ]( + *new_args, + missing_middleware_object=return_func_type + + return_func_middleware, + **new_kwargs, + ) communicator["return_func_type"] = "MMO:" # list for single return elif isinstance(communicator["return_func_type"], list): communicator["wrapped_executor"] = [] for comm_idx in range(len(communicator["return_func_type"])): - return_func_pub_kwargs = deepcopy(communicator["return_func_kwargs"][comm_idx]) + return_func_pub_kwargs = deepcopy( + communicator["return_func_kwargs"][comm_idx] + ) return_func_pub_kwargs.update( - return_func_pub_kwargs.get("publisher_kwargs", {})) + return_func_pub_kwargs.get("publisher_kwargs", {}) + ) return_func_pub_kwargs.pop("listener_kwargs", None) return_func_pub_kwargs.pop("publisher_kwargs", None) new_args, new_kwargs = match_args( - communicator["return_func_args"][comm_idx], return_func_pub_kwargs, wds[1:], - kwd) + communicator["return_func_args"][comm_idx], + return_func_pub_kwargs, + wds[1:], + kwd, + ) return_func_type = communicator["return_func_type"][comm_idx] - return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR) + return_func_middleware = new_kwargs.pop( + "middleware", DEFAULT_COMMUNICATOR + ) try: communicator["wrapped_executor"].append( - srv.Servers.registry[return_func_type + return_func_middleware]( - *new_args, **new_kwargs)) + srv.Servers.registry[ + return_func_type + return_func_middleware + ](*new_args, **new_kwargs) + ) except KeyError: communicator["wrapped_executor"].append( - srv.Servers.registry[ - "MMO:fallback"](*new_args, - missing_middleware_object=return_func_type + return_func_middleware, - **new_kwargs)) + srv.Servers.registry["MMO:fallback"]( + *new_args, + missing_middleware_object=return_func_type + + return_func_middleware, + **new_kwargs, + ) + ) communicator["return_func_type"][comm_idx] = "MMO:" returns = None for ret_idx, functor in enumerate( - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"]): + cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][ + "communicator" + ] + ): wrp_exec = functor["wrapped_executor"] # single element if isinstance(wrp_exec, srv.Server): @@ -270,7 +397,9 @@ def __trigger_reply(cls, func: Callable[..., Any], instance_id: str, kwd, *wds, return returns @classmethod - def __trigger_request(cls, func: Callable[..., Any], instance_id: str, kwd, *wds, **kwds): + def __trigger_request( + cls, func: Callable[..., Any], instance_id: str, kwd, *wds, **kwds + ): """ Triggers the request mode of the middleware communicator. @@ -283,57 +412,99 @@ def __trigger_request(cls, func: Callable[..., Any], instance_id: str, kwd, *wds :raises: KeyError: If the intended client type and middleware are unavailable, resorting to a fallback client """ - if "wrapped_executor" not in \ - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][0]: + if ( + "wrapped_executor" + not in cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"][0] + ): # instantiate the listeners - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"].reverse() - for communicator in cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"]: + cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][ + "communicator" + ].reverse() + for communicator in cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"]: # single element if isinstance(communicator["return_func_type"], str): - return_func_lsn_kwargs = deepcopy(communicator["return_func_kwargs"]) - return_func_lsn_kwargs.update(return_func_lsn_kwargs.get("listener_kwargs", {})) + return_func_lsn_kwargs = deepcopy( + communicator["return_func_kwargs"] + ) + return_func_lsn_kwargs.update( + return_func_lsn_kwargs.get("listener_kwargs", {}) + ) return_func_lsn_kwargs.pop("listener_kwargs", None) return_func_lsn_kwargs.pop("publisher_kwargs", None) - new_args, new_kwargs = match_args(communicator["return_func_args"], return_func_lsn_kwargs, wds[1:], - kwd) + new_args, new_kwargs = match_args( + communicator["return_func_args"], + return_func_lsn_kwargs, + wds[1:], + kwd, + ) return_func_type = communicator["return_func_type"] - return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR) + return_func_middleware = new_kwargs.pop( + "middleware", DEFAULT_COMMUNICATOR + ) try: communicator["wrapped_executor"] = clt.Clients.registry[ - return_func_type + return_func_middleware](*new_args, **new_kwargs) + return_func_type + return_func_middleware + ](*new_args, **new_kwargs) except KeyError: communicator["wrapped_executor"] = clt.Clients.registry[ - "MMO:fallback"](*new_args, - missing_middleware_object=return_func_type + return_func_middleware, - **new_kwargs) + "MMO:fallback" + ]( + *new_args, + missing_middleware_object=return_func_type + + return_func_middleware, + **new_kwargs, + ) communicator["return_func_type"] = "MMO:" # list for single return elif isinstance(communicator["return_func_type"], list): communicator["wrapped_executor"] = [] for comm_idx in range(len(communicator["return_func_type"])): - return_func_lsn_kwargs = deepcopy(communicator["return_func_kwargs"][comm_idx]) - return_func_lsn_kwargs.update(return_func_lsn_kwargs.get("listener_kwargs", {})) + return_func_lsn_kwargs = deepcopy( + communicator["return_func_kwargs"][comm_idx] + ) + return_func_lsn_kwargs.update( + return_func_lsn_kwargs.get("listener_kwargs", {}) + ) return_func_lsn_kwargs.pop("listener_kwargs", None) return_func_lsn_kwargs.pop("publisher_kwargs", None) - new_args, new_kwargs = match_args(communicator["return_func_args"][comm_idx], - return_func_lsn_kwargs, wds[1:], kwd) + new_args, new_kwargs = match_args( + communicator["return_func_args"][comm_idx], + return_func_lsn_kwargs, + wds[1:], + kwd, + ) return_func_type = communicator["return_func_type"][comm_idx] - return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR) + return_func_middleware = new_kwargs.pop( + "middleware", DEFAULT_COMMUNICATOR + ) try: communicator["wrapped_executor"].append( - clt.Clients.registry[return_func_type + return_func_middleware](*new_args, **new_kwargs)) + clt.Clients.registry[ + return_func_type + return_func_middleware + ](*new_args, **new_kwargs) + ) except KeyError: communicator["wrapped_executor"].append( - clt.Clients.registry[ - "MMO:fallback"](*new_args, - missing_middleware_object=return_func_type + return_func_middleware, - **new_kwargs)) + clt.Clients.registry["MMO:fallback"]( + *new_args, + missing_middleware_object=return_func_type + + return_func_middleware, + **new_kwargs, + ) + ) communicator["return_func_type"][comm_idx] = "MMO:" returns = [] for ret_idx, functor in enumerate( - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"]): + cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][ + "communicator" + ] + ): wrp_exec = functor["wrapped_executor"] # single element if isinstance(wrp_exec, clt.Client): @@ -349,7 +520,13 @@ def __trigger_request(cls, func: Callable[..., Any], instance_id: str, kwd, *wds return returns @classmethod - def register(cls, data_type: Union[str, List[Any]], middleware: str = DEFAULT_COMMUNICATOR, *args, **kwargs): + def register( + cls, + data_type: Union[str, List[Any]], + middleware: str = DEFAULT_COMMUNICATOR, + *args, + **kwargs, + ): """ Registers a function to the middleware communicator, defining its communication message type and associated middleware. Note that the function returned is a wrapper that can alter the behavior of the registered function @@ -363,6 +540,7 @@ def register(cls, data_type: Union[str, List[Any]], middleware: str = DEFAULT_CO :raises: NotImplementedError: If `data_type` is a dictionary or an unsupported type """ + def encapsulate(func): # define the communication message type (single element) if isinstance(data_type, str): @@ -376,29 +554,45 @@ def encapsulate(func): return_func_args, return_func_kwargs, return_func_type = [], [], [] for arg in data_type: data_spec = arg[0] + ":" - return_func_args.append([a for a in arg[2:] if not isinstance(a, dict)]) - return_func_kwargs.append(*[a for a in arg[2:] if isinstance(a, dict)]) + return_func_args.append( + [a for a in arg[2:] if not isinstance(a, dict)] + ) + return_func_kwargs.append( + *[a for a in arg[2:] if isinstance(a, dict)] + ) return_func_kwargs[-1]["middleware"] = str(arg[1]) return_func_type.append(data_spec) # define the communication message type (dict for single return). NOTE: supports 1 layer depth only elif isinstance(data_type, dict): - raise NotImplementedError("Dictionaries are not yet supported as a return type") + raise NotImplementedError( + "Dictionaries are not yet supported as a return type" + ) else: - raise NotImplementedError(f"Return data type not supported: {data_type}") + raise NotImplementedError( + f"Return data type not supported: {data_type}" + ) func_qualname = func.__qualname__ if func_qualname in cls.__registry: - cls.__registry[func_qualname]["communicator"].append({ - "return_func_args": return_func_args, - "return_func_kwargs": return_func_kwargs, - "return_func_type": return_func_type}) + cls.__registry[func_qualname]["communicator"].append( + { + "return_func_args": return_func_args, + "return_func_kwargs": return_func_kwargs, + "return_func_type": return_func_type, + } + ) else: - cls.__registry[func_qualname] = {"communicator": [{ - "return_func_args": return_func_args, - "return_func_kwargs": return_func_kwargs, - "return_func_type": return_func_type}]} + cls.__registry[func_qualname] = { + "communicator": [ + { + "return_func_args": return_func_args, + "return_func_kwargs": return_func_kwargs, + "return_func_type": return_func_type, + } + ] + } cls.__registry[func_qualname]["mode"] = None @wraps(func) @@ -408,47 +602,101 @@ def wrapper(*wds, **kwds): # triggers on calling the method instance_address = hex(id(wds[0])) try: - instance_id = cls._MiddlewareCommunicator__registry[func.__qualname__]["__WRAPYFI_INSTANCES"].index(instance_address) + 1 + instance_id = ( + cls._MiddlewareCommunicator__registry[func.__qualname__][ + "__WRAPYFI_INSTANCES" + ].index(instance_address) + + 1 + ) instance_id = "" if instance_id <= 1 else "." + str(instance_id) except KeyError: instance_id = "" # execute the method as usual - if cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] is None: + if ( + cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["mode"] + is None + ): return func(*wds, **kwds) kwd = get_default_args(func) kwd.update(kwds) - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["args"] = wds - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["kwargs"] = kwd + cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][ + "args" + ] = wds + cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][ + "kwargs" + ] = kwd # publishes the method returns - if cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] == "publish": + if ( + cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["mode"] + == "publish" + ): return cls.__trigger_publish(func, instance_id, kwd, *wds, **kwds) # listens to the publisher and returns the messages - elif cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] == "listen": + elif ( + cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["mode"] + == "listen" + ): return cls.__trigger_listen(func, instance_id, kwd, *wds, **kwds) # server awaits request from client and replies with method returns - elif cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] == "reply": + elif ( + cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["mode"] + == "reply" + ): return cls.__trigger_reply(func, instance_id, kwd, *wds, **kwds) # client requests with args from server and awaits reply - elif cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] == "request": + elif ( + cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["mode"] + == "request" + ): return cls.__trigger_request(func, instance_id, kwd, *wds, **kwds) # WARNING: use with caution. This produces "None" for all the method's returns - elif cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] == "disable": - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["last_results"] = [] - for ret_idx in range(len(cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"])): - cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["last_results"].append(None) - return cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["last_results"] + elif ( + cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["mode"] + == "disable" + ): + cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["last_results"] = [] + for ret_idx in range( + len( + cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["communicator"] + ) + ): + cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["last_results"].append(None) + return cls._MiddlewareCommunicator__registry[ + func.__qualname__ + instance_id + ]["last_results"] return wrapper + return encapsulate - def activate_communication(self, func: Union[str, Callable[..., Any]], mode: Union[str, List[str]]): + def activate_communication( + self, func: Union[str, Callable[..., Any]], mode: Union[str, List[str]] + ): """ Activates the communication mode for a registered function in the middleware communicator. The mode determines how the function will interact with the middleware communicator upon invocation, @@ -477,16 +725,24 @@ def activate_communication(self, func: Union[str, Callable[..., Any]], mode: Uni wrapyfi_instances.append(instance_addr) instance_id = len(wrapyfi_instances) if instance_id > 1: - self.__registry[f"{func.__qualname__}.{instance_id}"] = deepcopy(entry, exclude_keys=["wrapped_executor"]) + self.__registry[f"{func.__qualname__}.{instance_id}"] = ( + deepcopy(entry, exclude_keys=["wrapped_executor"]) + ) - instance_qualname = f"{func.__qualname__}.{instance_id}" if instance_id > 1 else func.__qualname__ + instance_qualname = ( + f"{func.__qualname__}.{instance_id}" + if instance_id > 1 + else func.__qualname__ + ) if isinstance(mode, list): try: - self.__registry[instance_qualname]["mode"] = mode[instance_id-1] + self.__registry[instance_qualname]["mode"] = mode[instance_id - 1] except IndexError: - raise IndexError("When mode (publish|listen|disable|null) specified in configuration file is a " - "list, No. of elements in the list should match the number of instances") + raise IndexError( + "When mode (publish|listen|disable|null) specified in configuration file is a " + "list, No. of elements in the list should match the number of instances" + ) else: self.__registry[instance_qualname]["mode"] = mode self.__registry[instance_qualname]["instance_addr"] = instance_addr @@ -538,11 +794,17 @@ def close_instance(cls, instance_addr: Optional[str] = None): for entry_key, entry_val in cls._MiddlewareCommunicator__registry.items(): if instance_addr in entry_val.get("__WRAPYFI_INSTANCES", []): if not del_entry: - del_entry_idx = entry_val["__WRAPYFI_INSTANCES"].index(instance_addr) + del_entry_idx = entry_val["__WRAPYFI_INSTANCES"].index( + instance_addr + ) if del_entry_idx == 0: del_entry = entry_key else: - del_entry = re.sub("\.\d+", "\.", entry_key) + "." + str(del_entry_idx + 1) + del_entry = ( + re.sub("\.\d+", "\.", entry_key) + + "." + + str(del_entry_idx + 1) + ) del_entry_name = re.sub("\.\d+", "\.", entry_key) else: if del_entry_name in entry_key: @@ -551,7 +813,9 @@ def close_instance(cls, instance_addr: Optional[str] = None): # delete registry entry and all its publishers/listeners/servers/clients if del_entry: - for communicator in cls._MiddlewareCommunicator__registry[del_entry]["communicator"]: + for communicator in cls._MiddlewareCommunicator__registry[del_entry][ + "communicator" + ]: wrapped_executor = communicator.get("wrapped_executor", False) if wrapped_executor: if isinstance(wrapped_executor, list): @@ -563,25 +827,29 @@ def close_instance(cls, instance_addr: Optional[str] = None): if other_entry_keys: del cls._MiddlewareCommunicator__registry[del_entry] else: - cls._MiddlewareCommunicator__registry[del_entry].pop("__WRAPYFI_INSTANCES") + cls._MiddlewareCommunicator__registry[del_entry].pop( + "__WRAPYFI_INSTANCES" + ) cls._MiddlewareCommunicator__registry[del_entry].pop("mode") - cls._MiddlewareCommunicator__registry[del_entry].pop("instance_addr") + cls._MiddlewareCommunicator__registry[del_entry].pop( + "instance_addr" + ) # shift all entries backwards following the deleted one if del_entry_idx - 1 < len(other_entry_keys): for other_entry_key in other_entry_keys[del_entry_idx:]: - new_key = re.split("\.(\d+)", other_entry_key) - if len(new_key) == 1: - new_key = new_key[0] - elif str(int(new_key[1]) - 1) == "1": - new_key = new_key[0] - else: - new_key = new_key[0] + "." + str(int(new_key[1]) - 1) - cls._MiddlewareCommunicator__registry[new_key] = \ - cls._MiddlewareCommunicator__registry.pop(other_entry_key) + new_key = re.split("\.(\d+)", other_entry_key) + if len(new_key) == 1: + new_key = new_key[0] + elif str(int(new_key[1]) - 1) == "1": + new_key = new_key[0] + else: + new_key = new_key[0] + "." + str(int(new_key[1]) - 1) + cls._MiddlewareCommunicator__registry[new_key] = ( + cls._MiddlewareCommunicator__registry.pop(other_entry_key) + ) else: break def __del__(self): self.close() - diff --git a/wrapyfi/encoders.py b/wrapyfi/encoders.py index 9d59f02..d4bd2f6 100644 --- a/wrapyfi/encoders.py +++ b/wrapyfi/encoders.py @@ -18,13 +18,14 @@ class JsonEncoder(json.JSONEncoder): - Numpy ndarray objects - Objects registered with the PluginRegistrar """ + def __init__(self, **kwargs): """ Initialize the JsonEncoder. :param kwargs: dict: Additional keyword arguments extracting values from the 'serializer_kwargs' key and passing them to the base class. All other keyword arguments are passed to the corresponding Plugin. """ - super().__init__(**kwargs.get('serializer_kwargs', {})) + super().__init__(**kwargs.get("serializer_kwargs", {})) self.plugins = dict() for plugin_key, plugin_val in PluginRegistrar.encoder_registry.items(): self.plugins[plugin_key] = plugin_val(**kwargs) @@ -37,7 +38,7 @@ def find_plugin(self, obj): :return: Plugin: The plugin for the given object if its type is registered, None otherwise """ for cls in reversed(type(obj).__mro__[:-1]): - if cls.__module__ == 'collections.abc': + if cls.__module__ == "collections.abc": continue # skip classes from collections.abc if issubclass(cls, abc.ABCMeta): if cls.__abstractmethods__: @@ -52,9 +53,10 @@ def encode(self, obj): :param obj: Any: The object to encode :return: str: The JSON string representation of the object returned by the base class """ + def hint_tuples(item): if isinstance(item, tuple): - return dict(__wrapyfi__=('tuple', item)) + return dict(__wrapyfi__=("tuple", item)) if isinstance(item, list): return [hint_tuples(e) for e in item] if isinstance(item, dict): @@ -72,19 +74,19 @@ def default(self, obj): :return: dict: A dictionary containing the class name and encoded data string """ if isinstance(obj, set): - return dict(__wrapyfi__=('set', list(obj))) + return dict(__wrapyfi__=("set", list(obj))) elif isinstance(obj, datetime): - return dict(__wrapyfi__=('datetime', obj.isoformat())) + return dict(__wrapyfi__=("datetime", obj.isoformat())) elif isinstance(obj, np.datetime64): - return dict(__wrapyfi__=('numpy.datetime64', str(obj))) + return dict(__wrapyfi__=("numpy.datetime64", str(obj))) elif isinstance(obj, (np.ndarray, np.generic)): with io.BytesIO() as memfile: np.save(memfile, obj) - obj_data = base64.b64encode(memfile.getvalue()).decode('ascii') - return dict(__wrapyfi__=('numpy.ndarray', obj_data)) + obj_data = base64.b64encode(memfile.getvalue()).decode("ascii") + return dict(__wrapyfi__=("numpy.ndarray", obj_data)) plugin_match = self.find_plugin(obj) if plugin_match is not None: @@ -106,6 +108,7 @@ class JsonDecodeHook(object): - Numpy ndarray objects - Objects registered with the PluginRegistrar """ + def __init__(self, **kwargs): """ Initialize the JsonDecodeHook. @@ -124,24 +127,26 @@ def object_hook(self, obj): :return: Any: The decoded object """ if isinstance(obj, dict): - wrapyfi = obj.get('__wrapyfi__', None) + wrapyfi = obj.get("__wrapyfi__", None) if wrapyfi is not None: obj_type = wrapyfi[0] - if obj_type == 'tuple': + if obj_type == "tuple": return tuple(wrapyfi[1]) - elif obj_type == 'set': + elif obj_type == "set": return set(wrapyfi[1]) - elif obj_type == 'datetime': + elif obj_type == "datetime": return datetime.fromisoformat(wrapyfi[1]) - elif obj_type == 'numpy.datetime64': + elif obj_type == "numpy.datetime64": return np.datetime64(wrapyfi[1]) - elif obj_type == 'numpy.ndarray': - with io.BytesIO(base64.b64decode(wrapyfi[1].encode('ascii'))) as memfile: + elif obj_type == "numpy.ndarray": + with io.BytesIO( + base64.b64decode(wrapyfi[1].encode("ascii")) + ) as memfile: return np.load(memfile) plugin_match = self.plugins.get(obj_type, None) diff --git a/wrapyfi/listeners/__init__.py b/wrapyfi/listeners/__init__.py index 95bd55e..65a8ce4 100755 --- a/wrapyfi/listeners/__init__.py +++ b/wrapyfi/listeners/__init__.py @@ -6,11 +6,22 @@ @Listeners.register("MMO", "fallback") class FallbackListener(Listener): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", - should_wait: bool = True, missing_middleware_object: str = "", **kwargs): - logging.warning(f"Fallback listener employed due to missing middleware or object type: " - f"{missing_middleware_object}") - Listener.__init__(self, name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs) + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + missing_middleware_object: str = "", + **kwargs, + ): + logging.warning( + f"Fallback listener employed due to missing middleware or object type: " + f"{missing_middleware_object}" + ) + Listener.__init__( + self, name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) self.missing_middleware_object = missing_middleware_object def establish(self, repeats: int = -1, **kwargs): @@ -20,4 +31,4 @@ def listen(self): return None def close(self): - return None \ No newline at end of file + return None diff --git a/wrapyfi/listeners/ros.py b/wrapyfi/listeners/ros.py index 456cdf3..fbd1772 100755 --- a/wrapyfi/listeners/ros.py +++ b/wrapyfi/listeners/ros.py @@ -24,8 +24,16 @@ class ROSListener(Listener): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, - queue_size: int = QUEUE_SIZE, ros_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + ros_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the subscriber. @@ -38,11 +46,15 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: :param kwargs: dict: Additional kwargs for the subscriber """ if carrier or carrier != "tcp": - logging.warning("[ROS] ROS does not support other carriers than TCP for PUB/SUB pattern. Using TCP.") + logging.warning( + "[ROS] ROS does not support other carriers than TCP for PUB/SUB pattern. Using TCP." + ) carrier = "tcp" - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs) + super().__init__( + name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) ROSMiddleware.activate(**ros_kwargs or {}) - + self.queue_size = queue_size def close(self): @@ -60,8 +72,16 @@ def __del__(self): @Listeners.register("NativeObject", "ros") class ROSNativeObjectListener(ROSListener): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int =QUEUE_SIZE, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObject listener using the ROS String message assuming the data is serialized as a JSON string. Deserializes the data (including plugins) using the decoder and parses it to a Python object. @@ -73,7 +93,14 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: :param queue_size: int: Size of the queue for the subscriber. Default is 5 :param deserializer_kwargs: dict: Additional kwargs for the deserializer """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + should_wait=should_wait, + queue_size=queue_size, + **kwargs, + ) self._subscriber = self._queue = None @@ -86,8 +113,16 @@ def establish(self): """ Establish the subscriber. """ - self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size) - self._subscriber = rospy.Subscriber(self.in_topic, std_msgs.msg.String, callback=self._message_callback) + self._queue = queue.Queue( + maxsize=( + 0 + if self.queue_size is None or self.queue_size <= 0 + else self.queue_size + ) + ) + self._subscriber = rospy.Subscriber( + self.in_topic, std_msgs.msg.String, callback=self._message_callback + ) self.established = True def listen(self) -> Any: @@ -100,7 +135,11 @@ def listen(self) -> Any: self.establish() try: obj_str = self._queue.get(block=self.should_wait) - return json.loads(obj_str, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + return json.loads( + obj_str, + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) except queue.Empty: return None @@ -113,14 +152,28 @@ def _message_callback(self, msg): try: self._queue.put(msg.data, block=False) except queue.Full: - logging.warning(f"[ROS] Discarding data because listener queue is full: {self.in_topic}") + logging.warning( + f"[ROS] Discarding data because listener queue is full: {self.in_topic}" + ) @Listeners.register("Image", "ros") class ROSImageListener(ROSListener): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE, - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + **kwargs, + ): """ The Image listener using the ROS Image message parsed to a numpy array. @@ -135,7 +188,14 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: :param fp: bool: True if the image is floating point, False if it is integer. Default is False :param jpg: bool: True if the image should be decompressed from JPG. Default is False """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + should_wait=should_wait, + queue_size=queue_size, + **kwargs, + ) self.width = width self.height = height self.rgb = rgb @@ -143,13 +203,13 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: self.jpg = jpg if self.fp: - self._encoding = '32FC3' if self.rgb else '32FC1' + self._encoding = "32FC3" if self.rgb else "32FC1" self._type = np.float32 else: - self._encoding = 'bgr8' if self.rgb else 'mono8' + self._encoding = "bgr8" if self.rgb else "mono8" self._type = np.uint8 if self.jpg: - self._encoding = 'jpeg' + self._encoding = "jpeg" self._type = np.uint8 self._pixel_bytes = (3 if self.rgb else 1) * np.dtype(self._type).itemsize @@ -162,11 +222,23 @@ def establish(self): """ Establish the subscriber. """ - self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size) + self._queue = queue.Queue( + maxsize=( + 0 + if self.queue_size is None or self.queue_size <= 0 + else self.queue_size + ) + ) if self.jpg: - self._subscriber = rospy.Subscriber(self.in_topic, sensor_msgs.msg.CompressedImage, callback=self._message_callback) + self._subscriber = rospy.Subscriber( + self.in_topic, + sensor_msgs.msg.CompressedImage, + callback=self._message_callback, + ) else: - self._subscriber = rospy.Subscriber(self.in_topic, sensor_msgs.msg.Image, callback=self._message_callback) + self._subscriber = rospy.Subscriber( + self.in_topic, sensor_msgs.msg.Image, callback=self._message_callback + ) self.established = True def listen(self): @@ -180,19 +252,32 @@ def listen(self): try: if self.jpg: format, data = self._queue.get(block=self.should_wait) - if format != 'jpeg': + if format != "jpeg": raise ValueError(f"Unsupported image format: {format}") if self.rgb: img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR) else: - img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE) + img = cv2.imdecode( + np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE + ) else: - height, width, encoding, is_bigendian, data = self._queue.get(block=self.should_wait) + height, width, encoding, is_bigendian, data = self._queue.get( + block=self.should_wait + ) if encoding != self._encoding: raise ValueError("Incorrect encoding for listener") - if 0 < self.width != width or 0 < self.height != height or len(data) != height * width * self._pixel_bytes: + if ( + 0 < self.width != width + or 0 < self.height != height + or len(data) != height * width * self._pixel_bytes + ): raise ValueError("Incorrect image shape for listener") - img = np.frombuffer(data, dtype=np.dtype(self._type).newbyteorder('>' if is_bigendian else '<')).reshape((height, width, -1)) + img = np.frombuffer( + data, + dtype=np.dtype(self._type).newbyteorder( + ">" if is_bigendian else "<" + ), + ).reshape((height, width, -1)) if img.shape[2] == 1: img = img.squeeze(axis=2) return img @@ -209,16 +294,37 @@ def _message_callback(self, data): if self.jpg: self._queue.put((data.format, data.data), block=False) else: - self._queue.put((data.height, data.width, data.encoding, data.is_bigendian, data.data), block=False) + self._queue.put( + ( + data.height, + data.width, + data.encoding, + data.is_bigendian, + data.data, + ), + block=False, + ) except queue.Full: - logging.warning(f"[ROS] Discarding data because listener queue is full: {self.in_topic}") + logging.warning( + f"[ROS] Discarding data because listener queue is full: {self.in_topic}" + ) @Listeners.register("AudioChunk", "ros") class ROSAudioChunkListener(ROSListener): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE, - channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + **kwargs, + ): """ The AudioChunk listener using the ROS Image message parsed to a numpy array. @@ -231,8 +337,19 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: :param rate: int: Sampling rate of the audio. Default is 44100 :param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio) """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, - width=chunk, height=channels, rgb=False, fp=True, jpg=False, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + should_wait=should_wait, + queue_size=queue_size, + width=chunk, + height=channels, + rgb=False, + fp=True, + jpg=False, + **kwargs, + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -248,13 +365,25 @@ def establish(self): from wrapyfi_ros_interfaces.msg import ROSAudioMessage except ImportError: import wrapyfi - logging.error("[ROS] Could not import ROSAudioMessage. " - "Make sure the ROS messages in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros_interfaces_lnk.html") + + logging.error( + "[ROS] Could not import ROSAudioMessage. " + "Make sure the ROS messages in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros_interfaces_lnk.html" + ) sys.exit(1) - self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size) - self._subscriber = rospy.Subscriber(self.in_topic, ROSAudioMessage, callback=self._message_callback) + self._queue = queue.Queue( + maxsize=( + 0 + if self.queue_size is None or self.queue_size <= 0 + else self.queue_size + ) + ) + self._subscriber = rospy.Subscriber( + self.in_topic, ROSAudioMessage, callback=self._message_callback + ) self.established = True def listen(self): @@ -266,14 +395,23 @@ def listen(self): if not self.established: self.establish() try: - chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(block=self.should_wait) + chunk, channels, rate, encoding, is_bigendian, data = self._queue.get( + block=self.should_wait + ) if 0 < self.rate != rate: raise ValueError("Incorrect audio rate for listener") - if encoding not in ['S16LE', 'S16BE']: + if encoding not in ["S16LE", "S16BE"]: raise ValueError("Incorrect encoding for listener") - if 0 < self.chunk != chunk or self.channels != channels or len(data) != chunk * channels * 4: + if ( + 0 < self.chunk != chunk + or self.channels != channels + or len(data) != chunk * channels * 4 + ): raise ValueError("Incorrect audio shape for listener") - aud = np.frombuffer(data, dtype=np.dtype(np.float32).newbyteorder('>' if is_bigendian else '<')).reshape((chunk, channels)) + aud = np.frombuffer( + data, + dtype=np.dtype(np.float32).newbyteorder(">" if is_bigendian else "<"), + ).reshape((chunk, channels)) # aud = aud / 32767.0 return aud, rate except queue.Empty: @@ -286,9 +424,21 @@ def _message_callback(self, data): :param data: wrapyfi_ros_interfaces.msg.ROSAudioMessage: The received message """ try: - self._queue.put((data.chunk_size, data.channels, data.sample_rate, data.encoding, data.is_bigendian, data.data), block=False) + self._queue.put( + ( + data.chunk_size, + data.channels, + data.sample_rate, + data.encoding, + data.is_bigendian, + data.data, + ), + block=False, + ) except queue.Full: - logging.warning(f"[ROS] Discarding data because listener queue is full: {self.in_topic}") + logging.warning( + f"[ROS] Discarding data because listener queue is full: {self.in_topic}" + ) @Listeners.register("Properties", "ros") @@ -300,7 +450,16 @@ class ROSPropertiesListener(ROSListener): but care should be taken when using dictionaries, since they are analogous with node namespaces: http://wiki.ros.org/rospy/Overview/Parameter%20Server """ - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE, **kwargs): + + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + **kwargs, + ): """ The PropertiesListener using the ROS Parameter Server. @@ -310,7 +469,14 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: :param should_wait: bool: Whether the subscriber should wait for a parameter to be set. Default is True :param queue_size: int: Size of the queue for the subscriber. Default is 5 """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + should_wait=should_wait, + queue_size=queue_size, + **kwargs, + ) self._subscriber = self._queue = None if not self.should_wait: @@ -318,7 +484,9 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: self.previous_property = False - def await_connection(self, in_topic: Optional[int] = None, repeats: Optional[int] = None): + def await_connection( + self, in_topic: Optional[int] = None, repeats: Optional[int] = None + ): """ Wait for a parameter to be set. @@ -375,7 +543,15 @@ def listen(self): @Listeners.register("ROSMessage", "ros") class ROSMessageListener(ROSListener): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + **kwargs, + ): """ The ROSMessageListener using the ROS message type inferred from the message type. Supports standard ROS msgs. @@ -385,7 +561,14 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: :param should_wait: bool: Whether the subscriber should wait for a message to be published. Default is True :param queue_size: int: Size of the queue for the subscriber. Default is 5 """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + should_wait=should_wait, + queue_size=queue_size, + **kwargs, + ) self._subscriber = self._queue = self._topic_type = None ListenerWatchDog().add_listener(self) @@ -394,11 +577,21 @@ def establish(self): """ Establish the subscriber. """ - self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size) - self._topic_type, topic_str, _ = rostopic.get_topic_class(self.in_topic, blocking=self.should_wait) + self._queue = queue.Queue( + maxsize=( + 0 + if self.queue_size is None or self.queue_size <= 0 + else self.queue_size + ) + ) + self._topic_type, topic_str, _ = rostopic.get_topic_class( + self.in_topic, blocking=self.should_wait + ) if self._topic_type is None: return - self._subscriber = rospy.Subscriber(self.in_topic, self._topic_type, callback=self._message_callback) + self._subscriber = rospy.Subscriber( + self.in_topic, self._topic_type, callback=self._message_callback + ) self.established = True def listen(self): @@ -425,5 +618,6 @@ def _message_callback(self, msg): try: self._queue.put(msg, block=False) except queue.Full: - logging.warning(f"[ROS] Discarding data because listener queue is full: {self.in_topic}") - + logging.warning( + f"[ROS] Discarding data because listener queue is full: {self.in_topic}" + ) diff --git a/wrapyfi/listeners/ros2.py b/wrapyfi/listeners/ros2.py index 984528b..d498dae 100755 --- a/wrapyfi/listeners/ros2.py +++ b/wrapyfi/listeners/ros2.py @@ -27,8 +27,15 @@ class ROS2Listener(Listener, Node): - def __init__(self, name: str, in_topic: str, should_wait: bool = True, - queue_size: int = QUEUE_SIZE, ros2_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + ros2_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the subscriber. @@ -41,12 +48,16 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True, """ carrier = "tcp" if "carrier" in kwargs and kwargs["carrier"] not in ["", None]: - logging.warning("[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP.") + logging.warning( + "[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP." + ) if "carrier" in kwargs: del kwargs["carrier"] ROS2Middleware.activate(**ros2_kwargs or {}) - Listener.__init__(self, name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs) + Listener.__init__( + self, name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) Node.__init__(self, name + str(hex(id(self))), allow_undeclared_parameters=True) self.queue_size = queue_size @@ -66,8 +77,15 @@ def __del__(self): @Listeners.register("NativeObject", "ros2") class ROS2NativeObjectListener(ROS2Listener): - def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObject listener using the ROS 2 String message assuming the data is serialized as a JSON string. Deserializes the data (including plugins) using the decoder and parses it to a native object. @@ -78,7 +96,9 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_siz :param queue_size: int: Size of the queue for the subscriber. Default is 5 :param deserializer_kwargs: dict: Additional kwargs for the deserializer """ - super().__init__(name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs + ) self._subscriber = self._queue = None @@ -91,8 +111,19 @@ def establish(self): """ Establish the subscriber. """ - self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size) - self._subscriber = self.create_subscription(std_msgs.msg.String, self.in_topic, callback=self._message_callback, qos_profile=self.queue_size) + self._queue = queue.Queue( + maxsize=( + 0 + if self.queue_size is None or self.queue_size <= 0 + else self.queue_size + ) + ) + self._subscriber = self.create_subscription( + std_msgs.msg.String, + self.in_topic, + callback=self._message_callback, + qos_profile=self.queue_size, + ) self.established = True def listen(self): @@ -106,7 +137,11 @@ def listen(self): try: rclpy.spin_once(self, timeout_sec=WAIT[self.should_wait]) obj_str = self._queue.get(block=self.should_wait) - return json.loads(obj_str, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + return json.loads( + obj_str, + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) except queue.Empty: return None @@ -119,14 +154,27 @@ def _message_callback(self, msg): try: self._queue.put(msg.data, block=False) except queue.Full: - logging.warning(f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}") + logging.warning( + f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}" + ) @Listeners.register("Image", "ros2") class ROS2ImageListener(ROS2Listener): - def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE, - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + **kwargs, + ): """ The Image listener using the ROS 2 Image message parsed to a numpy array. @@ -140,7 +188,9 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_siz :param fp: bool: True if the image is floating point, False if it is integer. Default is False :param jpg: bool: True if the image should be decompressed from JPG. Default is False """ - super().__init__(name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs + ) self.width = width self.height = height self.rgb = rgb @@ -148,13 +198,13 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_siz self.jpg = jpg if self.fp: - self._encoding = '32FC3' if self.rgb else '32FC1' + self._encoding = "32FC3" if self.rgb else "32FC1" self._type = np.float32 else: - self._encoding = 'bgr8' if self.rgb else 'mono8' + self._encoding = "bgr8" if self.rgb else "mono8" self._type = np.uint8 if self.jpg: - self._encoding = 'jpeg' + self._encoding = "jpeg" self._type = np.uint8 self._pixel_bytes = (3 if self.rgb else 1) * np.dtype(self._type).itemsize @@ -167,11 +217,27 @@ def establish(self): """ Establish the subscriber """ - self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size) + self._queue = queue.Queue( + maxsize=( + 0 + if self.queue_size is None or self.queue_size <= 0 + else self.queue_size + ) + ) if self.jpg: - self._subscriber = self.create_subscription(sensor_msgs.msg.CompressedImage, self.in_topic, callback=self._message_callback, qos_profile=self.queue_size) + self._subscriber = self.create_subscription( + sensor_msgs.msg.CompressedImage, + self.in_topic, + callback=self._message_callback, + qos_profile=self.queue_size, + ) else: - self._subscriber = self.create_subscription(sensor_msgs.msg.Image, self.in_topic, callback=self._message_callback, qos_profile=self.queue_size) + self._subscriber = self.create_subscription( + sensor_msgs.msg.Image, + self.in_topic, + callback=self._message_callback, + qos_profile=self.queue_size, + ) self.established = True def listen(self): @@ -186,19 +252,32 @@ def listen(self): rclpy.spin_once(self, timeout_sec=WAIT[self.should_wait]) if self.jpg: format, data = self._queue.get(block=self.should_wait) - if format != 'jpeg': + if format != "jpeg": raise ValueError(f"Unsupported image format: {format}") if self.rgb: img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR) else: - img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE) + img = cv2.imdecode( + np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE + ) else: - height, width, encoding, is_bigendian, data = self._queue.get(block=self.should_wait) + height, width, encoding, is_bigendian, data = self._queue.get( + block=self.should_wait + ) if encoding != self._encoding: raise ValueError("Incorrect encoding for listener") - if 0 < self.width != width or 0 < self.height != height or len(data) != height * width * self._pixel_bytes: + if ( + 0 < self.width != width + or 0 < self.height != height + or len(data) != height * width * self._pixel_bytes + ): raise ValueError("Incorrect image shape for listener") - img = np.frombuffer(data, dtype=np.dtype(self._type).newbyteorder('>' if is_bigendian else '<')).reshape((height, width, -1)) + img = np.frombuffer( + data, + dtype=np.dtype(self._type).newbyteorder( + ">" if is_bigendian else "<" + ), + ).reshape((height, width, -1)) if img.shape[2] == 1: img = img.squeeze(axis=2) return img @@ -215,16 +294,30 @@ def _message_callback(self, msg): if self.jpg: self._queue.put((msg.format, msg.data), block=False) else: - self._queue.put((msg.height, msg.width, msg.encoding, msg.is_bigendian, msg.data), block=False) + self._queue.put( + (msg.height, msg.width, msg.encoding, msg.is_bigendian, msg.data), + block=False, + ) except queue.Full: - logging.warning(f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}") + logging.warning( + f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}" + ) @Listeners.register("AudioChunk", "ros2") class ROS2AudioChunkListener(ROS2Listener): - def __init__(self, name: str, in_topic: str, should_wait: bool = True, - queue_size: int = QUEUE_SIZE, channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + **kwargs, + ): """ The AudioChunk listener using the ROS 2 Audio message parsed to a numpy array. @@ -236,7 +329,9 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True, :param rate: int: Sampling rate of the audio. Default is 44100 :param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio) """ - super().__init__(name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -253,13 +348,28 @@ def establish(self): from wrapyfi_ros2_interfaces.msg import ROS2AudioMessage except ImportError: import wrapyfi - logging.error("[ROS 2] Could not import ROS2AudioMessage. " - "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros2_interfaces_lnk.html") + + logging.error( + "[ROS 2] Could not import ROS2AudioMessage. " + "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros2_interfaces_lnk.html" + ) sys.exit(1) - self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size) - self._subscriber = self.create_subscription(ROS2AudioMessage, self.in_topic, callback=self._message_callback, qos_profile=self.queue_size) + self._queue = queue.Queue( + maxsize=( + 0 + if self.queue_size is None or self.queue_size <= 0 + else self.queue_size + ) + ) + self._subscriber = self.create_subscription( + ROS2AudioMessage, + self.in_topic, + callback=self._message_callback, + qos_profile=self.queue_size, + ) self.established = True def listen(self): @@ -272,14 +382,23 @@ def listen(self): self.establish() try: rclpy.spin_once(self, timeout_sec=WAIT[self.should_wait]) - chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(block=self.should_wait) + chunk, channels, rate, encoding, is_bigendian, data = self._queue.get( + block=self.should_wait + ) if 0 < self.rate != rate: raise ValueError("Incorrect audio rate for publisher") - if encoding not in ['S16LE', 'S16BE']: + if encoding not in ["S16LE", "S16BE"]: raise ValueError("Incorrect encoding for listener") - if 0 < self.chunk != chunk or self.channels != channels or len(data) != chunk * channels * 4: + if ( + 0 < self.chunk != chunk + or self.channels != channels + or len(data) != chunk * channels * 4 + ): raise ValueError("Incorrect audio shape for listener") - aud = np.frombuffer(data, dtype=np.dtype(np.float32).newbyteorder('>' if is_bigendian else '<')).reshape((chunk, channels)) + aud = np.frombuffer( + data, + dtype=np.dtype(np.float32).newbyteorder(">" if is_bigendian else "<"), + ).reshape((chunk, channels)) # aud = aud / 32767.0 return aud, rate except queue.Empty: @@ -292,9 +411,21 @@ def _message_callback(self, msg): :param msg: wrapyfi_ros2_interfaces.msg.ROS2AudioMessage: The received message """ try: - self._queue.put((msg.chunk_size, msg.channels, msg.sample_rate, msg.encoding, msg.is_bigendian, msg.data), block=False) + self._queue.put( + ( + msg.chunk_size, + msg.channels, + msg.sample_rate, + msg.encoding, + msg.is_bigendian, + msg.data, + ), + block=False, + ) except queue.Full: - logging.warning(f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}") + logging.warning( + f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}" + ) @Listeners.register("Properties", "ros2") @@ -307,7 +438,14 @@ def __init__(self, name, in_topic, **kwargs): @Listeners.register("ROS2Message", "ros2") class ROS2MessageListener(ROS2Listener): - def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + **kwargs, + ): """ The ROS2MessageListener using the ROS 2 message type inferred from the message type. Supports standard ROS 2 msgs. @@ -316,8 +454,16 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_siz :param should_wait: bool: Whether the subscriber should wait for the publisher to transmit a message. Default is True :param queue_size: int: Size of the queue for the subscriber. Default is 5 """ - super().__init__(name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs) - self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size) + super().__init__( + name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs + ) + self._queue = queue.Queue( + maxsize=( + 0 + if self.queue_size is None or self.queue_size <= 0 + else self.queue_size + ) + ) def get_topic_type(self, topic_name): """ @@ -345,11 +491,16 @@ def establish(self): if not topic_type_str: return None - module_name, class_name = topic_type_str.rsplit('/', 1) - module_name = module_name.replace('/', '.') + module_name, class_name = topic_type_str.rsplit("/", 1) + module_name = module_name.replace("/", ".") MessageType = getattr(importlib.import_module(module_name), class_name) - self._subscriber = self.create_subscription(MessageType, self.in_topic, callback=self._message_callback, qos_profile=self.queue_size) + self._subscriber = self.create_subscription( + MessageType, + self.in_topic, + callback=self._message_callback, + qos_profile=self.queue_size, + ) self.established = True def listen(self): @@ -377,4 +528,6 @@ def _message_callback(self, msg): try: self._queue.put(msg, block=False) except queue.Full: - logging.warning(f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}") + logging.warning( + f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}" + ) diff --git a/wrapyfi/listeners/yarp.py b/wrapyfi/listeners/yarp.py index ec2f812..470252f 100755 --- a/wrapyfi/listeners/yarp.py +++ b/wrapyfi/listeners/yarp.py @@ -19,8 +19,16 @@ class YarpListener(Listener): - def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True, - persistent: bool = True, yarp_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + should_wait: bool = True, + persistent: bool = True, + yarp_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the subscriber. @@ -32,14 +40,18 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca :param yarp_kwargs: dict: Additional kwargs for the Yarp middleware :param kwargs: dict: Additional kwargs for the subscriber """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs) + super().__init__( + name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) self.style = yarp.ContactStyle() self.style.persistent = persistent self.style.carrier = self.carrier YarpMiddleware.activate(**yarp_kwargs or {}) - def await_connection(self, in_topic: Optional[str] = None, repeats: Optional[int] = None): + def await_connection( + self, in_topic: Optional[str] = None, repeats: Optional[int] = None + ): """ Wait for the publisher to connect to the subscriber. @@ -95,8 +107,16 @@ def __del__(self): @Listeners.register("NativeObject", "yarp") class YarpNativeObjectListener(YarpListener): - def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True, - persistent: bool = True, deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + should_wait: bool = True, + persistent: bool = True, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObject listener using the BufferedPortBottle string construct assuming the data is serialized as a JSON string. Deserializes the data (including plugins) using the decoder and parses it to a Python object. @@ -108,7 +128,14 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca :param persistent: bool: Whether the subscriber port should remain connected after closure. Default is True :param deserializer_kwargs: dict: Additional kwargs for the deserializer """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, persistent=persistent, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + should_wait=should_wait, + persistent=persistent, + **kwargs, + ) self._port = self._netconnect = None self._plugin_decoder_hook = JsonDecodeHook(**kwargs).object_hook @@ -130,9 +157,13 @@ def establish(self, repeats: Optional[int] = None, **kwargs): rnd_id = str(np.random.randint(100000, size=1)[0]) self._port.open(self.in_topic + ":in" + rnd_id) if self.style.persistent: - self._netconnect = yarp.Network.connect(self.in_topic, self.in_topic + ":in" + rnd_id, self.style) + self._netconnect = yarp.Network.connect( + self.in_topic, self.in_topic + ":in" + rnd_id, self.style + ) else: - self._netconnect = yarp.Network.connect(self.in_topic, self.in_topic + ":in" + rnd_id, self.carrier) + self._netconnect = yarp.Network.connect( + self.in_topic, self.in_topic + ":in" + rnd_id, self.carrier + ) return self.check_establishment(established) def listen(self): @@ -147,7 +178,11 @@ def listen(self): return None obj_port = self.read_port(self._port) if obj_port is not None: - return json.loads(obj_port.get(0).asString(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + return json.loads( + obj_port.get(0).asString(), + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) else: return None @@ -155,8 +190,20 @@ def listen(self): @Listeners.register("Image", "yarp") class YarpImageListener(YarpListener): - def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True, - persistent: bool = True, width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + should_wait: bool = True, + persistent: bool = True, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + **kwargs, + ): """ The Image listener using the BufferedPortImage construct parsed to a numpy array. @@ -171,7 +218,14 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca :param fp: bool: True if the image is floating point, False if it is integer. Default is False :param jpg: bool: True if the image should be decompressed from JPG. Default is False """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, persistent=persistent, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + should_wait=should_wait, + persistent=persistent, + **kwargs, + ) self.width = width self.height = height self.rgb = rgb @@ -195,16 +249,30 @@ def establish(self, repeats: Optional[int] = None, **kwargs): if self.jpg: self._port = yarp.BufferedPortBottle() elif self.rgb: - self._port = yarp.BufferedPortImageRgbFloat() if self.fp else yarp.BufferedPortImageRgb() + self._port = ( + yarp.BufferedPortImageRgbFloat() + if self.fp + else yarp.BufferedPortImageRgb() + ) else: - self._port = yarp.BufferedPortImageFloat() if self.fp else yarp.BufferedPortImageMono() + self._port = ( + yarp.BufferedPortImageFloat() + if self.fp + else yarp.BufferedPortImageMono() + ) self._type = np.float32 if self.fp else np.uint8 - in_topic_connect = f"{self.in_topic}:in{np.random.randint(100000, size=1).item()}" + in_topic_connect = ( + f"{self.in_topic}:in{np.random.randint(100000, size=1).item()}" + ) self._port.open(in_topic_connect) if self.style.persistent: - self._netconnect = yarp.Network.connect(self.in_topic, in_topic_connect, self.style) + self._netconnect = yarp.Network.connect( + self.in_topic, in_topic_connect, self.style + ) else: - self._netconnect = yarp.Network.connect(self.in_topic, in_topic_connect, self.carrier) + self._netconnect = yarp.Network.connect( + self.in_topic, in_topic_connect, self.carrier + ) return self.check_establishment(established) def listen(self): @@ -222,21 +290,34 @@ def listen(self): return None if self.jpg: img_str = ret_img_msg.get(0).asString() - with io.BytesIO(base64.b64decode(img_str.encode('ascii'))) as memfile: + with io.BytesIO(base64.b64decode(img_str.encode("ascii"))) as memfile: img_str = np.load(memfile) if self.rgb: img = cv2.imdecode(np.frombuffer(img_str, np.uint8), cv2.IMREAD_COLOR) else: - img = cv2.imdecode(np.frombuffer(img_str, np.uint8), cv2.IMREAD_GRAYSCALE) + img = cv2.imdecode( + np.frombuffer(img_str, np.uint8), cv2.IMREAD_GRAYSCALE + ) return img else: - if 0 < self.width != ret_img_msg.width() or 0 < self.height != ret_img_msg.height(): + if ( + 0 < self.width != ret_img_msg.width() + or 0 < self.height != ret_img_msg.height() + ): raise ValueError("Incorrect image shape for listener") elif self.rgb: - img = np.zeros((ret_img_msg.height(), ret_img_msg.width(), 3), dtype=self._type, order='C') + img = np.zeros( + (ret_img_msg.height(), ret_img_msg.width(), 3), + dtype=self._type, + order="C", + ) img_port = yarp.ImageRgbFloat() if self.fp else yarp.ImageRgb() else: - img = np.zeros((ret_img_msg.height(), ret_img_msg.width()), dtype=self._type, order='C') + img = np.zeros( + (ret_img_msg.height(), ret_img_msg.width()), + dtype=self._type, + order="C", + ) img_port = yarp.ImageFloat() if self.fp else yarp.ImageMono() img_port.resize(img.shape[1], img.shape[0]) img_port.setExternal(img.data, img.shape[1], img.shape[0]) @@ -247,8 +328,18 @@ def listen(self): @Listeners.register("AudioChunk", "yarp") class YarpAudioChunkListener(YarpListener): - def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True, - persistent: bool = True, channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + should_wait: bool = True, + persistent: bool = True, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + **kwargs, + ): """ The AudioChunk listener using the Sound construct parsed as a numpy array. @@ -261,7 +352,14 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca :param rate: int: Sampling rate of the audio. Default is 44100 :param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio) """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, persistent=persistent, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + should_wait=should_wait, + persistent=persistent, + **kwargs, + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -283,7 +381,9 @@ def establish(self, repeats: Optional[int] = None, **kwargs): rnd_id = str(np.random.randint(100000, size=1)[0]) self._port = yarp.Port() self._port.open(self.in_topic + ":in" + rnd_id) - self._netconnect = yarp.Network.connect(self.in_topic, self.in_topic + ":in" + rnd_id, self.carrier) + self._netconnect = yarp.Network.connect( + self.in_topic, self.in_topic + ":in" + rnd_id, self.carrier + ) self._sound_msg = yarp.Sound() self._port.read(self._sound_msg) @@ -307,7 +407,10 @@ def listen(self): if not established: return None self._port.read(self._sound_msg) - aud = np.array([self._sound_msg.get(i) for i in range(self._sound_msg.getSamples())], dtype=np.int16) + aud = np.array( + [self._sound_msg.get(i) for i in range(self._sound_msg.getSamples())], + dtype=np.int16, + ) aud = aud.astype(np.float32) / 32767.0 return aud, self.rate diff --git a/wrapyfi/listeners/zeromq.py b/wrapyfi/listeners/zeromq.py index b668ddf..961ee8e 100755 --- a/wrapyfi/listeners/zeromq.py +++ b/wrapyfi/listeners/zeromq.py @@ -15,18 +15,32 @@ SOCKET_IP = os.environ.get("WRAPYFI_ZEROMQ_SOCKET_IP", "127.0.0.1") SOCKET_PUB_PORT = int(os.environ.get("WRAPYFI_ZEROMQ_SOCKET_PUB_PORT", 5555)) -ZEROMQ_PUBSUB_MONITOR_TOPIC = os.environ.get("WRAPYFI_ZEROMQ_PUBSUB_MONITOR_TOPIC", "ZEROMQ/CONNECTIONS") -ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN = os.environ.get("WRAPYFI_ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN", "process") +ZEROMQ_PUBSUB_MONITOR_TOPIC = os.environ.get( + "WRAPYFI_ZEROMQ_PUBSUB_MONITOR_TOPIC", "ZEROMQ/CONNECTIONS" +) +ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN = os.environ.get( + "WRAPYFI_ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN", "process" +) WATCHDOG_POLL_REPEAT = None class ZeroMQListener(Listener): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, - socket_ip: str = SOCKET_IP, socket_pub_port: int = SOCKET_PUB_PORT, - pubsub_monitor_topic: str = ZEROMQ_PUBSUB_MONITOR_TOPIC, - pubsub_monitor_listener_spawn: Optional[str] = ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN, - zeromq_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + socket_ip: str = SOCKET_IP, + socket_pub_port: int = SOCKET_PUB_PORT, + pubsub_monitor_topic: str = ZEROMQ_PUBSUB_MONITOR_TOPIC, + pubsub_monitor_listener_spawn: Optional[ + str + ] = ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN, + zeromq_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the subscriber. @@ -44,20 +58,28 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: :param kwargs: dict: Additional kwargs for the subscriber """ if carrier or carrier != "tcp": - logging.warning("[ZeroMQ] ZeroMQ does not support other carriers than TCP for PUB/SUB pattern. Using TCP.") + logging.warning( + "[ZeroMQ] ZeroMQ does not support other carriers than TCP for PUB/SUB pattern. Using TCP." + ) carrier = "tcp" - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs) + super().__init__( + name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) self.socket_address = f"{carrier}://{socket_ip}:{socket_pub_port}" - ZeroMQMiddlewarePubSub.activate(socket_pub_address=self.socket_address, - pubsub_monitor_topic=pubsub_monitor_topic, - pubsub_monitor_listener_spawn=pubsub_monitor_listener_spawn, - **zeromq_kwargs or {}) + ZeroMQMiddlewarePubSub.activate( + socket_pub_address=self.socket_address, + pubsub_monitor_topic=pubsub_monitor_topic, + pubsub_monitor_listener_spawn=pubsub_monitor_listener_spawn, + **zeromq_kwargs or {}, + ) ZeroMQMiddlewarePubSub().shared_monitor_data.add_topic(self.in_topic) - def await_connection(self, socket=None, in_topic: Optional[str] = None, repeats: Optional[int] = None): + def await_connection( + self, socket=None, in_topic: Optional[str] = None, repeats: Optional[int] = None + ): """ Wait for the connection to be established. @@ -78,7 +100,9 @@ def await_connection(self, socket=None, in_topic: Optional[str] = None, repeats: while repeats > 0 or repeats <= -1: repeats -= 1 - connected = ZeroMQMiddlewarePubSub().shared_monitor_data.is_connected(in_topic) + connected = ZeroMQMiddlewarePubSub().shared_monitor_data.is_connected( + in_topic + ) if connected: logging.info(f"[ZeroMQ] Connected to input port: {in_topic}") break @@ -118,8 +142,15 @@ def __del__(self): @Listeners.register("NativeObject", "zeromq") class ZeroMQNativeObjectListener(ZeroMQListener): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObject listener using the ZeroMQ message construct assuming the data is serialized as a JSON string. Deserializes the data (including plugins) using the decoder and parses it to a native object. @@ -130,7 +161,9 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: :param should_wait: bool: Whether the subscriber should wait for the publisher to transmit a message. Default is True :param deserializer_kwargs: dict: Additional kwargs for the deserializer """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs) + super().__init__( + name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) self._socket = self._netconnect = None self._plugin_decoder_hook = JsonDecodeHook(**kwargs).object_hook @@ -149,9 +182,13 @@ def establish(self, repeats: Optional[int] = None, **kwargs): self._socket = zmq.Context.instance().socket(zmq.SUB) for socket_property in ZeroMQMiddlewarePubSub().zeromq_kwargs.items(): if isinstance(socket_property[1], str): - self._socket.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1]) + self._socket.setsockopt_string( + getattr(zmq, socket_property[0]), socket_property[1] + ) else: - self._socket.setsockopt(getattr(zmq, socket_property[0]), socket_property[1]) + self._socket.setsockopt( + getattr(zmq, socket_property[0]), socket_property[1] + ) self._socket.connect(self.socket_address) self._topic = self.in_topic.encode("utf-8") self._socket.setsockopt_string(zmq.SUBSCRIBE, self.in_topic) @@ -173,7 +210,11 @@ def listen(self): if self._socket.poll(timeout=None if self.should_wait else 0): obj = self._socket.recv_multipart() if obj is not None: - return json.loads(obj[1].decode(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + return json.loads( + obj[1].decode(), + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) else: return None else: @@ -183,8 +224,19 @@ def listen(self): @Listeners.register("Image", "zeromq") class ZeroMQImageListener(ZeroMQNativeObjectListener): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + **kwargs, + ): """ The Image listener using the ZeroMQ message construct parsed to a numpy array. @@ -198,7 +250,9 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: :param fp: bool: True if the image is floating point, False if it is integer. Default is False :param jpg: bool: True if the image should be decompressed from JPG. Default is False """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs) + super().__init__( + name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) self.width = width self.height = height self.rgb = rgb @@ -223,14 +277,28 @@ def listen(self): return None elif self.jpg: if self.rgb: - img = cv2.imdecode(np.frombuffer(obj[2], np.uint8), cv2.IMREAD_COLOR) + img = cv2.imdecode( + np.frombuffer(obj[2], np.uint8), cv2.IMREAD_COLOR + ) else: - img = cv2.imdecode(np.frombuffer(obj[2], np.uint8), cv2.IMREAD_GRAYSCALE) + img = cv2.imdecode( + np.frombuffer(obj[2], np.uint8), cv2.IMREAD_GRAYSCALE + ) return img else: - img = json.loads(obj[2].decode(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) - if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \ - not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)): + img = json.loads( + obj[2].decode(), + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) + if ( + 0 < self.width != img.shape[1] + or 0 < self.height != img.shape[0] + or not ( + (img.ndim == 2 and not self.rgb) + or (img.ndim == 3 and self.rgb and img.shape[2] == 3) + ) + ): raise ValueError("Incorrect image shape for listener") return img else: @@ -239,8 +307,17 @@ def listen(self): @Listeners.register("AudioChunk", "zeromq") class ZeroMQAudioChunkListener(ZeroMQImageListener): - def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, - channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs): + def __init__( + self, + name: str, + in_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + **kwargs, + ): """ The AudioChunk listener using the ZeroMQ message construct parsed to a numpy array. @@ -253,8 +330,18 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: :param rate: int: Sampling rate of the audio. Default is 44100 :param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio) """ - super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, - width=chunk, height=channels, rgb=False, fp=True, jpg=False, **kwargs) + super().__init__( + name, + in_topic, + carrier=carrier, + should_wait=should_wait, + width=chunk, + height=channels, + rgb=False, + fp=True, + jpg=False, + **kwargs, + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -271,10 +358,22 @@ def listen(self): return None if self._socket.poll(timeout=None if self.should_wait else 0): obj = self._socket.recv_multipart() - chunk, channels, rate, aud = json.loads(obj[2].decode(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) if obj is not None else None + chunk, channels, rate, aud = ( + json.loads( + obj[2].decode(), + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) + if obj is not None + else None + ) if 0 < self.rate != rate: raise ValueError("Incorrect audio rate for listener") - if 0 < self.chunk != chunk or self.channels != channels or aud.size != chunk * channels: + if ( + 0 < self.chunk != chunk + or self.channels != channels + or aud.size != chunk * channels + ): raise ValueError("Incorrect audio shape for listener") return aud, rate else: diff --git a/wrapyfi/middlewares/ros.py b/wrapyfi/middlewares/ros.py index 56e9441..1f77e16 100755 --- a/wrapyfi/middlewares/ros.py +++ b/wrapyfi/middlewares/ros.py @@ -17,6 +17,7 @@ class ROSMiddleware(metaclass=SingletonOptimized): and destroy all connections. The ``activate`` and ``deinit`` methods are automatically called when the class is instantiated and when the program exits, respectively. """ + @staticmethod def activate(**kwargs): """ @@ -26,7 +27,14 @@ def activate(**kwargs): """ ROSMiddleware(**kwargs) - def __init__(self, node_name: str = "wrapyfi", anonymous: bool = True, disable_signals: bool = True, *args, **kwargs): + def __init__( + self, + node_name: str = "wrapyfi", + anonymous: bool = True, + disable_signals: bool = True, + *args, + **kwargs, + ): """ Initialize the ROS middleware. This method is automatically called when the class is instantiated. @@ -47,18 +55,18 @@ def deinit(): Deinitialize the ROS middleware. This method is automatically called when the program exits. """ logging.info("Deinitializing ROS middleware") - rospy.signal_shutdown('Deinit') + rospy.signal_shutdown("Deinit") class ROSNativeObjectService(object): - _type = 'wrapyfi_services/ROSNativeObject' - _md5sum = '46a550fd1ca640b396e26ebf988aed7b' # AddTwoInts '6a2e34150c00229791cc89ff309fff21' - _request_class = std_msgs.msg.String - _response_class = std_msgs.msg.String + _type = "wrapyfi_services/ROSNativeObject" + _md5sum = "46a550fd1ca640b396e26ebf988aed7b" # AddTwoInts '6a2e34150c00229791cc89ff309fff21' + _request_class = std_msgs.msg.String + _response_class = std_msgs.msg.String class ROSImageService(object): - _type = 'wrapyfi_services/ROSImage' - _md5sum = 'f720f2021b4bbbe86b0f93b08906381c' # AddTwoInts '6a2e34150c00229791cc89ff309fff21' - _request_class = std_msgs.msg.String - _response_class = sensor_msgs.msg.Image + _type = "wrapyfi_services/ROSImage" + _md5sum = "f720f2021b4bbbe86b0f93b08906381c" # AddTwoInts '6a2e34150c00229791cc89ff309fff21' + _request_class = std_msgs.msg.String + _response_class = sensor_msgs.msg.Image diff --git a/wrapyfi/middlewares/ros2.py b/wrapyfi/middlewares/ros2.py index 6786dc1..b1abd50 100755 --- a/wrapyfi/middlewares/ros2.py +++ b/wrapyfi/middlewares/ros2.py @@ -45,5 +45,6 @@ def deinit(): if rclpy.ok(): rclpy.shutdown() else: - logging.info("ROS 2 context is already shutdown or not initialized. Skipping shutdown.") - + logging.info( + "ROS 2 context is already shutdown or not initialized. Skipping shutdown." + ) diff --git a/wrapyfi/middlewares/yarp.py b/wrapyfi/middlewares/yarp.py index 96ba3b0..c0d300b 100755 --- a/wrapyfi/middlewares/yarp.py +++ b/wrapyfi/middlewares/yarp.py @@ -14,6 +14,7 @@ class YarpMiddleware(metaclass=SingletonOptimized): and destroy all connections. The ``activate`` and ``deinit`` methods are automatically called when the class is instantiated and when the program exits, respectively. """ + @staticmethod def activate(**kwargs): """ diff --git a/wrapyfi/middlewares/zeromq.py b/wrapyfi/middlewares/zeromq.py index 59c42cc..bc70aae 100644 --- a/wrapyfi/middlewares/zeromq.py +++ b/wrapyfi/middlewares/zeromq.py @@ -12,8 +12,20 @@ from wrapyfi.utils import SingletonOptimized from wrapyfi.connect.wrapper import MiddlewareCommunicator -ZEROMQ_POST_OPTS = ["SUBSCRIBE", "UNSUBSCRIBE", "LINGER", "ROUTER_HANDOVER", "ROUTER_MANDATORY", "PROBE_ROUTER", - "XPUB_VERBOSE", "XPUB_VERBOSER", "REQ_CORRELATE", "REQ_RELAXED", "SNDHWM", "RCVHWM"] +ZEROMQ_POST_OPTS = [ + "SUBSCRIBE", + "UNSUBSCRIBE", + "LINGER", + "ROUTER_HANDOVER", + "ROUTER_MANDATORY", + "PROBE_ROUTER", + "XPUB_VERBOSE", + "XPUB_VERBOSER", + "REQ_CORRELATE", + "REQ_RELAXED", + "SNDHWM", + "RCVHWM", +] class ZeroMQMiddlewarePubSub(metaclass=SingletonOptimized): @@ -135,9 +147,18 @@ def activate(**kwargs): except AttributeError: pass - ZeroMQMiddlewarePubSub(zeromq_proxy_kwargs=kwargs, zeromq_post_kwargs=zeromq_post_kwargs, **zeromq_pre_kwargs) - - def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwargs: Optional[dict] = None, **kwargs): + ZeroMQMiddlewarePubSub( + zeromq_proxy_kwargs=kwargs, + zeromq_post_kwargs=zeromq_post_kwargs, + **zeromq_pre_kwargs, + ) + + def __init__( + self, + zeromq_proxy_kwargs: Optional[dict] = None, + zeromq_post_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the ZeroMQ PUB/SUB middleware. This method is automatically called when the class is instantiated. @@ -151,9 +172,13 @@ def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwarg self.ctx = zmq.Context.instance() for socket_property in kwargs.items(): if isinstance(socket_property[1], str): - self.ctx.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1]) + self.ctx.setsockopt_string( + getattr(zmq, socket_property[0]), socket_property[1] + ) else: - self.ctx.setsockopt(getattr(zmq, socket_property[0]), socket_property[1]) + self.ctx.setsockopt( + getattr(zmq, socket_property[0]), socket_property[1] + ) atexit.register(MiddlewareCommunicator.close_all_instances) atexit.register(self.deinit) @@ -162,38 +187,59 @@ def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwarg # start the pubsub monitor listener if zeromq_proxy_kwargs.get("pubsub_monitor_listener_spawn", False): if zeromq_proxy_kwargs["pubsub_monitor_listener_spawn"] == "process": - self.shared_monitor_data = self.ZeroMQSharedMonitorData(use_multiprocessing=True) - self.monitor = multiprocessing.Process(name='zeromq_pubsub_monitor_listener', - target=self.__init_monitor_listener, - kwargs=zeromq_proxy_kwargs) + self.shared_monitor_data = self.ZeroMQSharedMonitorData( + use_multiprocessing=True + ) + self.monitor = multiprocessing.Process( + name="zeromq_pubsub_monitor_listener", + target=self.__init_monitor_listener, + kwargs=zeromq_proxy_kwargs, + ) self.monitor.daemon = True self.monitor.start() else: # if threaded - self.shared_monitor_data = self.ZeroMQSharedMonitorData(use_multiprocessing=False) - self.monitor = threading.Thread(name='pubsub_monitor_listener_spawn', - target=self.__init_monitor_listener, kwargs=zeromq_proxy_kwargs) - self.monitor.setDaemon(True) # deprecation warning Python3.10. Previous Python versions only support this + self.shared_monitor_data = self.ZeroMQSharedMonitorData( + use_multiprocessing=False + ) + self.monitor = threading.Thread( + name="pubsub_monitor_listener_spawn", + target=self.__init_monitor_listener, + kwargs=zeromq_proxy_kwargs, + ) + self.monitor.setDaemon( + True + ) # deprecation warning Python3.10. Previous Python versions only support this self.monitor.start() time.sleep(1) if zeromq_proxy_kwargs.get("start_proxy_broker", False): if zeromq_proxy_kwargs["proxy_broker_spawn"] == "process": - self.proxy = multiprocessing.Process(name='zeromq_pubsub_broker', target=self.__init_proxy, - kwargs=zeromq_proxy_kwargs) + self.proxy = multiprocessing.Process( + name="zeromq_pubsub_broker", + target=self.__init_proxy, + kwargs=zeromq_proxy_kwargs, + ) self.proxy.daemon = True self.proxy.start() else: # if threaded - self.proxy = threading.Thread(name='zeromq_pubsub_broker', target=self.__init_proxy, - kwargs=zeromq_proxy_kwargs) - self.proxy.setDaemon(True) # deprecation warning Python3.10. Previous Python versions only support this + self.proxy = threading.Thread( + name="zeromq_pubsub_broker", + target=self.__init_proxy, + kwargs=zeromq_proxy_kwargs, + ) + self.proxy.setDaemon( + True + ) # deprecation warning Python3.10. Previous Python versions only support this self.proxy.start() pass @staticmethod - def proxy_thread(socket_pub_address: str = "tcp://127.0.0.1:5555", - socket_sub_address: str = "tcp://127.0.0.1:5556", - inproc_address: str = "inproc://monitor"): + def proxy_thread( + socket_pub_address: str = "tcp://127.0.0.1:5555", + socket_sub_address: str = "tcp://127.0.0.1:5556", + inproc_address: str = "inproc://monitor", + ): """ Proxy thread for the ZeroMQ PUB/SUB proxy. @@ -224,12 +270,17 @@ def proxy_thread(socket_pub_address: str = "tcp://127.0.0.1:5555", if str(e) == "Socket operation on non-socket": pass else: - logging.error(f"[ZeroMQ BROKER] An error occurred in the ZeroMQ proxy: {str(e)}.") + logging.error( + f"[ZeroMQ BROKER] An error occurred in the ZeroMQ proxy: {str(e)}." + ) @staticmethod - def subscription_monitor_thread(inproc_address: str = "inproc://monitor", - socket_sub_address: str = "tcp://127.0.0.1:5556", - pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS", verbose: bool = False): + def subscription_monitor_thread( + inproc_address: str = "inproc://monitor", + socket_sub_address: str = "tcp://127.0.0.1:5556", + pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS", + verbose: bool = False, + ): """ Subscription monitor thread for the ZeroMQ PUB/SUB proxy. @@ -259,10 +310,12 @@ def subscription_monitor_thread(inproc_address: str = "inproc://monitor", # ensure the message is a subscription/unsubscription message if len(message) > 1 and (message[0] == 1 or message[0] == 0): event = message[0] - topic = message[1:].decode('utf-8') + topic = message[1:].decode("utf-8") if verbose: - logging.info(f"[ZeroMQ BROKER] Received event: {event}, topic: {topic}") + logging.info( + f"[ZeroMQ BROKER] Received event: {event}, topic: {topic}" + ) # avoid processing messages on the monitor topic. if topic == pubsub_monitor_topic: @@ -275,19 +328,29 @@ def subscription_monitor_thread(inproc_address: str = "inproc://monitor", topic_subscriber_count[topic] = 0 if verbose: - logging.info(f"[ZeroMQ BROKER] Current topic subscriber count: {dict(topic_subscriber_count)}") + logging.info( + f"[ZeroMQ BROKER] Current topic subscriber count: {dict(topic_subscriber_count)}" + ) # publish the updated counts publisher.send_multipart( - [pubsub_monitor_topic.encode(), json.dumps(dict(topic_subscriber_count)).encode()] + [ + pubsub_monitor_topic.encode(), + json.dumps(dict(topic_subscriber_count)).encode(), + ] ) except Exception as e: - logging.error(f"[ZeroMQ BROKER] An error occurred in the ZeroMQ subscription monitor: {str(e)}") - - def __init_proxy(self, socket_pub_address: str = "tcp://127.0.0.1:5555", - socket_sub_address: str = "tcp://127.0.0.1:5556", - pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS", - **kwargs): + logging.error( + f"[ZeroMQ BROKER] An error occurred in the ZeroMQ subscription monitor: {str(e)}" + ) + + def __init_proxy( + self, + socket_pub_address: str = "tcp://127.0.0.1:5555", + socket_sub_address: str = "tcp://127.0.0.1:5556", + pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS", + **kwargs, + ): """ Initialize the ZeroMQ PUB/SUB proxy and subscription monitor. @@ -298,20 +361,32 @@ def __init_proxy(self, socket_pub_address: str = "tcp://127.0.0.1:5555", """ inproc_address = "inproc://monitor" - threading.Thread(target=self.proxy_thread, - kwargs={"socket_pub_address": socket_pub_address, - "socket_sub_address": socket_sub_address, - "inproc_address": inproc_address}).start(), - - threading.Thread(target=self.subscription_monitor_thread, - kwargs={"socket_sub_address": socket_sub_address, - "inproc_address": inproc_address, - "pubsub_monitor_topic": pubsub_monitor_topic, - "verbose": kwargs.get("verbose", False)}).start() - - def __init_monitor_listener(self, socket_pub_address: str = "tcp://127.0.0.1:5555", - pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS", - verbose: bool = False, **kwargs): + threading.Thread( + target=self.proxy_thread, + kwargs={ + "socket_pub_address": socket_pub_address, + "socket_sub_address": socket_sub_address, + "inproc_address": inproc_address, + }, + ).start(), + + threading.Thread( + target=self.subscription_monitor_thread, + kwargs={ + "socket_sub_address": socket_sub_address, + "inproc_address": inproc_address, + "pubsub_monitor_topic": pubsub_monitor_topic, + "verbose": kwargs.get("verbose", False), + }, + ).start() + + def __init_monitor_listener( + self, + socket_pub_address: str = "tcp://127.0.0.1:5555", + pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS", + verbose: bool = False, + **kwargs, + ): """ Initialize the ZeroMQ PUB/SUB monitor listener. @@ -328,7 +403,7 @@ def __init_monitor_listener(self, socket_pub_address: str = "tcp://127.0.0.1:555 while True: _, message = subscriber.recv_multipart() - data = json.loads(message.decode('utf-8')) + data = json.loads(message.decode("utf-8")) topic = list(data.keys())[0] if verbose: logging.info(f"[ZeroMQ] Topic: {topic}, Data: {data}") @@ -336,17 +411,23 @@ def __init_monitor_listener(self, socket_pub_address: str = "tcp://127.0.0.1:555 if topic in self.shared_monitor_data.get_topics(): self.shared_monitor_data.update_connection(topic, data) if data[topic] == 0: - logging.info(f"[ZeroMQ] Subscriber disconnected from topic: {topic}") + logging.info( + f"[ZeroMQ] Subscriber disconnected from topic: {topic}" + ) self.shared_monitor_data.remove_connection(topic) else: logging.info(f"[ZeroMQ] Subscriber connected to topic: {topic}") if verbose: for monitored_topic in self.shared_monitor_data.get_topics(): - logging.info(f"[ZeroMQ] Monitored topic from main process: {monitored_topic}") + logging.info( + f"[ZeroMQ] Monitored topic from main process: {monitored_topic}" + ) except Exception as e: - logging.error(f"[ZeroMQ] An error occurred in the ZeroMQ subscription monitor listener: {str(e)}") + logging.error( + f"[ZeroMQ] An error occurred in the ZeroMQ subscription monitor listener: {str(e)}" + ) @staticmethod def deinit(): @@ -381,10 +462,19 @@ def activate(**kwargs): except AttributeError: pass - ZeroMQMiddlewareReqRep(zeromq_proxy_kwargs=kwargs, zeromq_post_kwargs=zeromq_post_kwargs, **zeromq_pre_kwargs) - - def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwargs: Optional[dict] = None, *args, - **kwargs): + ZeroMQMiddlewareReqRep( + zeromq_proxy_kwargs=kwargs, + zeromq_post_kwargs=zeromq_post_kwargs, + **zeromq_pre_kwargs, + ) + + def __init__( + self, + zeromq_proxy_kwargs: Optional[dict] = None, + zeromq_post_kwargs: Optional[dict] = None, + *args, + **kwargs, + ): """ Initialize the ZeroMQ REQ/REP middleware. This method is automatically called when the class is instantiated. @@ -399,28 +489,43 @@ def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwarg self.ctx = zmq.Context.instance() for socket_property in kwargs.items(): if isinstance(socket_property[1], str): - self.ctx.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1]) + self.ctx.setsockopt_string( + getattr(zmq, socket_property[0]), socket_property[1] + ) else: - self.ctx.setsockopt(getattr(zmq, socket_property[0]), socket_property[1]) + self.ctx.setsockopt( + getattr(zmq, socket_property[0]), socket_property[1] + ) atexit.register(MiddlewareCommunicator.close_all_instances) atexit.register(self.deinit) if zeromq_proxy_kwargs is not None and zeromq_proxy_kwargs: if zeromq_proxy_kwargs["proxy_broker_spawn"] == "process": - self.proxy = multiprocessing.Process(name='zeromq_reqrep_broker', target=self.__init_device, - kwargs=zeromq_proxy_kwargs) + self.proxy = multiprocessing.Process( + name="zeromq_reqrep_broker", + target=self.__init_device, + kwargs=zeromq_proxy_kwargs, + ) self.proxy.daemon = True self.proxy.start() else: # if threaded - self.proxy = threading.Thread(name='zeromq_reqrep_broker', target=self.__init_device, - kwargs=zeromq_proxy_kwargs) - self.proxy.setDaemon(True) # deprecation warning Python3.10. Previous Python versions only support this + self.proxy = threading.Thread( + name="zeromq_reqrep_broker", + target=self.__init_device, + kwargs=zeromq_proxy_kwargs, + ) + self.proxy.setDaemon( + True + ) # deprecation warning Python3.10. Previous Python versions only support this self.proxy.start() pass @staticmethod - def __init_device(socket_rep_address: str = "tcp://127.0.0.1:5559", - socket_req_address: str = "tcp://127.0.0.1:5560", **kwargs): + def __init_device( + socket_rep_address: str = "tcp://127.0.0.1:5559", + socket_req_address: str = "tcp://127.0.0.1:5560", + **kwargs, + ): """ Initialize the ZeroMQ REQ/REP device broker. @@ -453,7 +558,9 @@ def __init_device(socket_rep_address: str = "tcp://127.0.0.1:5559", if str(e) == "Socket operation on non-socket": pass else: - logging.error(f"[ZeroMQ] An error occurred in the ZeroMQ proxy: {str(e)}.") + logging.error( + f"[ZeroMQ] An error occurred in the ZeroMQ proxy: {str(e)}." + ) @staticmethod def deinit(): @@ -485,11 +592,19 @@ def activate(**kwargs): except AttributeError: pass - ZeroMQMiddlewareParamServer(zeromq_proxy_kwargs=kwargs, zeromq_post_kwargs=zeromq_post_kwargs, - **zeromq_pre_kwargs) - - def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwargs: Optional = None, *args, - **kwargs): + ZeroMQMiddlewareParamServer( + zeromq_proxy_kwargs=kwargs, + zeromq_post_kwargs=zeromq_post_kwargs, + **zeromq_pre_kwargs, + ) + + def __init__( + self, + zeromq_proxy_kwargs: Optional[dict] = None, + zeromq_post_kwargs: Optional = None, + *args, + **kwargs, + ): """ Initialize the ZeroMQ parameter server middleware. This method is automatically called when the class is instantiated. @@ -504,9 +619,13 @@ def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwarg self.ctx = zmq.Context.instance() for socket_property in kwargs.items(): if isinstance(socket_property[1], str): - self.ctx.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1]) + self.ctx.setsockopt_string( + getattr(zmq, socket_property[0]), socket_property[1] + ) else: - self.ctx.setsockopt(getattr(zmq, socket_property[0]), socket_property[1]) + self.ctx.setsockopt( + getattr(zmq, socket_property[0]), socket_property[1] + ) atexit.register(MiddlewareCommunicator.close_all_instances) atexit.register(self.deinit) @@ -516,32 +635,55 @@ def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwarg self.params = self.manager.dict() self.params["WRAPYFI_ACTIVE"] = "True" if zeromq_proxy_kwargs["proxy_broker_spawn"] == "process": - self.param_broadcaster = multiprocessing.Process(name='zeromq_param_broadcaster', - target=self.__init_broadcaster, - kwargs=zeromq_proxy_kwargs, args=(self.params,)) + self.param_broadcaster = multiprocessing.Process( + name="zeromq_param_broadcaster", + target=self.__init_broadcaster, + kwargs=zeromq_proxy_kwargs, + args=(self.params,), + ) self.param_broadcaster.daemon = True self.param_broadcaster.start() - self.param_server = multiprocessing.Process(name='zeromq_param_server', target=self.__init_server, - kwargs=zeromq_proxy_kwargs, args=(self.params,)) + self.param_server = multiprocessing.Process( + name="zeromq_param_server", + target=self.__init_server, + kwargs=zeromq_proxy_kwargs, + args=(self.params,), + ) self.param_server.daemon = True self.param_server.start() else: # if threaded - self.param_broadcaster = threading.Thread(name='zeromq_param_broadcaster', - target=self.__init_broadcaster, - kwargs=zeromq_proxy_kwargs, args=(self.params,)) - self.param_broadcaster.setDaemon(True) # deprecation warning Python3.10. Previous Python versions only support this + self.param_broadcaster = threading.Thread( + name="zeromq_param_broadcaster", + target=self.__init_broadcaster, + kwargs=zeromq_proxy_kwargs, + args=(self.params,), + ) + self.param_broadcaster.setDaemon( + True + ) # deprecation warning Python3.10. Previous Python versions only support this self.param_broadcaster.start() - self.param_server = threading.Thread(name='zeromq_param_server', target=self.__init_server, - kwargs=zeromq_proxy_kwargs, args=(self.params,)) - self.param_server.setDaemon(True) # deprecation warning Python3.10. Previous Python versions only support this + self.param_server = threading.Thread( + name="zeromq_param_server", + target=self.__init_server, + kwargs=zeromq_proxy_kwargs, + args=(self.params,), + ) + self.param_server.setDaemon( + True + ) # deprecation warning Python3.10. Previous Python versions only support this self.param_server.start() pass @staticmethod - def __init_broadcaster(params, param_pub_address: str = "tcp://127.0.0.1:5655", - param_sub_address: str = "tcp://127.0.0.1:5656", - param_poll_interval=1, verbose=False, **kwargs): + def __init_broadcaster( + params, + param_pub_address: str = "tcp://127.0.0.1:5655", + param_sub_address: str = "tcp://127.0.0.1:5656", + param_poll_interval=1, + verbose=False, + **kwargs, + ): """ Initialize the ZeroMQ parameter server broadcaster. @@ -580,7 +722,9 @@ def __init_broadcaster(params, param_pub_address: str = "tcp://127.0.0.1:5655", if xpub_socket in event: message = xpub_socket.recv_multipart() if verbose: - logging.info("[ZeroMQ BROKER] xpub_socket recv message: %r" % message) + logging.info( + "[ZeroMQ BROKER] xpub_socket recv message: %r" % message + ) if message[0].startswith(b"\x00"): root_topics.remove(message[0][1:].decode("utf-8")) elif message[0].startswith(b"\x01"): @@ -590,29 +734,52 @@ def __init_broadcaster(params, param_pub_address: str = "tcp://127.0.0.1:5655", if xsub_socket in event: message = xsub_socket.recv_multipart() if verbose: - logging.info("[ZeroMQ BROKER] xsub_socket recv message: %r" % message) + logging.info( + "[ZeroMQ BROKER] xsub_socket recv message: %r" % message + ) if message[0].startswith(b"\x01") or message[0].startswith(b"\x00"): xpub_socket.send_multipart(message) else: fltr_key = message[0].decode("utf-8") - fltr_message = {key: val for key, val in params.items() - if key.startswith(fltr_key)} + fltr_message = { + key: val + for key, val in params.items() + if key.startswith(fltr_key) + } if verbose: - logging.info("[ZeroMQ BROKER] xsub_socket filtered message: %r" % fltr_message) + logging.info( + "[ZeroMQ BROKER] xsub_socket filtered message: %r" + % fltr_message + ) for key, val in fltr_message.items(): prefix, param = key.rsplit("/", 1) if "/" in key else ("", key) - xpub_socket.send_multipart([prefix.encode("utf-8"), param.encode("utf-8"), val.encode("utf-8")]) + xpub_socket.send_multipart( + [ + prefix.encode("utf-8"), + param.encode("utf-8"), + val.encode("utf-8"), + ] + ) # xpub_socket.send_multipart(message) if event: update_trigger = True if param_server is not None: - update_trigger, cached_params = ZeroMQMiddlewareParamServer.publish_params( - param_server, params, cached_params, root_topics, update_trigger) + update_trigger, cached_params = ( + ZeroMQMiddlewareParamServer.publish_params( + param_server, params, cached_params, root_topics, update_trigger + ) + ) @staticmethod - def publish_params(param_server, params: dict, cached_params: dict, root_topics: set, update_trigger: bool): + def publish_params( + param_server, + params: dict, + cached_params: dict, + root_topics: set, + update_trigger: bool, + ): """ Publish parameters to the parameter server. @@ -637,12 +804,16 @@ def publish_params(param_server, params: dict, cached_params: dict, root_topics: # publish updates for all parameters to subscribed clients for key, val in params.items(): prefix, param = key.rsplit("/", 1) if "/" in key else ("", key) - param_server.send_multipart([prefix.encode("utf-8"), param.encode("utf-8"), val.encode("utf-8")]) + param_server.send_multipart( + [prefix.encode("utf-8"), param.encode("utf-8"), val.encode("utf-8")] + ) return update_trigger, cached_params @staticmethod - def __init_server(params: dict, param_reqrep_address: str = "tcp://127.0.0.1:5659", **kwargs): + def __init_server( + params: dict, param_reqrep_address: str = "tcp://127.0.0.1:5659", **kwargs + ): """ Initialize the ZeroMQ parameter server. @@ -658,7 +829,11 @@ def __init_server(params: dict, param_reqrep_address: str = "tcp://127.0.0.1:565 if request.startswith("get"): try: # extract the parameter name and namespace prefix from the request - prefix, param = request[4:].rsplit("/", 1) if "/" in request[4:] else ("", request[4:]) + prefix, param = ( + request[4:].rsplit("/", 1) + if "/" in request[4:] + else ("", request[4:]) + ) # construct the full parameter name with the namespace prefix full_param = "/".join([prefix, param]) if prefix else param if full_param in params: diff --git a/wrapyfi/plugins/cupy_array.py b/wrapyfi/plugins/cupy_array.py index f4aac63..00e318b 100644 --- a/wrapyfi/plugins/cupy_array.py +++ b/wrapyfi/plugins/cupy_array.py @@ -27,6 +27,7 @@ try: import cupy as cp + HAVE_CUPY = True except ImportError: HAVE_CUPY = False @@ -75,7 +76,9 @@ def __init__(self, load_cupy_device=None, map_cupy_devices=None, **kwargs): self.map_cupy_devices = map_cupy_devices or {} if load_cupy_device is not None: self.map_cupy_devices["default"] = load_cupy_device - self.map_cupy_devices = {k: cupy_device_to_str(v) for k, v in self.map_cupy_devices.items()} + self.map_cupy_devices = { + k: cupy_device_to_str(v) for k, v in self.map_cupy_devices.items() + } def encode(self, obj, *args, **kwargs): """ @@ -86,9 +89,11 @@ def encode(self, obj, *args, **kwargs): """ with io.BytesIO() as memfile: np.save(memfile, cp.asnumpy(obj)) - obj_data = base64.b64encode(memfile.getvalue()).decode('ascii') + obj_data = base64.b64encode(memfile.getvalue()).decode("ascii") obj_device = cupy_device_to_str(obj.device) - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device)) + return True, dict( + __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device) + ) def decode(self, obj_type, obj_full, *args, **kwargs): """ @@ -97,9 +102,10 @@ def decode(self, obj_type, obj_full, *args, **kwargs): :param obj_full: tuple: A tuple containing the encoded data string and device string :return: Tuple[bool, cp.ndarray] """ - with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile: - obj_device_str = self.map_cupy_devices.get(obj_full[2], self.map_cupy_devices.get("default", "cuda:0")) + with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile: + obj_device_str = self.map_cupy_devices.get( + obj_full[2], self.map_cupy_devices.get("default", "cuda:0") + ) obj_device = cupy_str_to_device(obj_device_str) with obj_device: return True, cp.array(np.load(memfile)) - diff --git a/wrapyfi/plugins/dask_data.py b/wrapyfi/plugins/dask_data.py index f8b3e40..e206168 100644 --- a/wrapyfi/plugins/dask_data.py +++ b/wrapyfi/plugins/dask_data.py @@ -29,12 +29,19 @@ import dask.dataframe as dd import dask.array as da import pandas as pd + HAVE_DASK = True except ImportError: HAVE_DASK = False -@PluginRegistrar.register(types=None if not HAVE_DASK else dd.DataFrame.__mro__[:-1] + dd.Series.__mro__[:-1] + da.Array.__mro__[:-1]) +@PluginRegistrar.register( + types=( + None + if not HAVE_DASK + else dd.DataFrame.__mro__[:-1] + dd.Series.__mro__[:-1] + da.Array.__mro__[:-1] + ) +) class DaskData(Plugin): def __init__(self, **kwargs): """ @@ -59,20 +66,27 @@ def encode(self, obj, *args, **kwargs): with io.BytesIO() as memfile: if isinstance(obj, dd.DataFrame): pandas_df = obj.compute().reset_index() - memfile.write(pandas_df.to_json(orient="records").encode('ascii')) + memfile.write(pandas_df.to_json(orient="records").encode("ascii")) obj_partitions = obj.npartitions - obj_type = 'DataFrame' + obj_type = "DataFrame" elif isinstance(obj, dd.Series): pandas_ds = obj.compute().reset_index() - memfile.write(pandas_ds.to_json(orient="records").encode('ascii')) + memfile.write(pandas_ds.to_json(orient="records").encode("ascii")) obj_partitions = obj.npartitions - obj_type = 'Series' + obj_type = "Series" elif isinstance(obj, da.Array): np.save(memfile, obj.compute(), allow_pickle=True) - obj_type = 'Array' + obj_type = "Array" memfile.seek(0) - obj_data = base64.b64encode(memfile.read()).decode('ascii') - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type, obj_partitions)) + obj_data = base64.b64encode(memfile.read()).decode("ascii") + return True, dict( + __wrapyfi__=( + str(self.__class__.__name__), + obj_data, + obj_type, + obj_partitions, + ) + ) def decode(self, obj_type, obj_full, *args, **kwargs): """ @@ -86,18 +100,26 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - bool: Always True, indicating that the decoding was successful - Union[dd.DataFrame, da.Array]: The decoded Dask data """ - obj_data = base64.b64decode(obj_full[1].encode('ascii')) + obj_data = base64.b64decode(obj_full[1].encode("ascii")) obj_type = obj_full[2] obj_partitions = obj_full[3] with io.BytesIO(obj_data) as memfile: - if obj_type == 'DataFrame': - pandas_df = pd.read_json(memfile.read().decode('ascii'), orient="records") - pandas_df.set_index('index', inplace=True) # Set the index back after reading from JSON + if obj_type == "DataFrame": + pandas_df = pd.read_json( + memfile.read().decode("ascii"), orient="records" + ) + pandas_df.set_index( + "index", inplace=True + ) # Set the index back after reading from JSON return True, dd.from_pandas(pandas_df, npartitions=obj_partitions) - elif obj_type == 'Series': - pandas_ds = pd.read_json(memfile.read().decode('ascii'), orient="records") - pandas_ds.set_index('index', inplace=True) # Set the index back after reading from JSON + elif obj_type == "Series": + pandas_ds = pd.read_json( + memfile.read().decode("ascii"), orient="records" + ) + pandas_ds.set_index( + "index", inplace=True + ) # Set the index back after reading from JSON return True, dd.from_pandas(pandas_ds, npartitions=obj_partitions) - elif obj_type == 'Array': + elif obj_type == "Array": np_array = np.load(memfile, allow_pickle=True) - return True, da.from_array(np_array, chunks=np_array.shape) \ No newline at end of file + return True, da.from_array(np_array, chunks=np_array.shape) diff --git a/wrapyfi/plugins/jax_tensor.py b/wrapyfi/plugins/jax_tensor.py index fdf43ff..d7d3bf1 100644 --- a/wrapyfi/plugins/jax_tensor.py +++ b/wrapyfi/plugins/jax_tensor.py @@ -26,6 +26,7 @@ try: import jax + HAVE_JAX = True except ImportError: HAVE_JAX = False @@ -59,7 +60,7 @@ def encode(self, obj, *args, **kwargs): """ with io.BytesIO() as memfile: np.save(memfile, np.asarray(obj)) - obj_data = base64.b64encode(memfile.getvalue()).decode('ascii') + obj_data = base64.b64encode(memfile.getvalue()).decode("ascii") return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data)) def decode(self, obj_type, obj_full, *args, **kwargs): @@ -74,6 +75,5 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - bool: Always True, indicating that the decoding was successful - jax.numpy.DeviceArray: The decoded JAX tensor data """ - with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile: + with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile: return True, jax.numpy.array(np.load(memfile)) - diff --git a/wrapyfi/plugins/mxnet_tensor.py b/wrapyfi/plugins/mxnet_tensor.py index 400bf6f..6e396cf 100644 --- a/wrapyfi/plugins/mxnet_tensor.py +++ b/wrapyfi/plugins/mxnet_tensor.py @@ -28,6 +28,7 @@ try: import mxnet + HAVE_MXNET = True except ImportError: HAVE_MXNET = False @@ -41,7 +42,7 @@ def mxnet_device_to_str(device): :return: Union[str, dict]: A string or dictionary representing the MXNet device """ if device is None: - return 'cpu:0' + return "cpu:0" elif isinstance(device, dict): device_rets = {} for k, v in device.items(): @@ -52,7 +53,7 @@ def mxnet_device_to_str(device): elif isinstance(device, str): return device.replace("gpu", "cuda") else: - raise ValueError(f'Unknown device type {type(device)}') + raise ValueError(f"Unknown device type {type(device)}") @lru_cache(maxsize=None) @@ -69,16 +70,18 @@ def mxnet_str_to_device(device): return device elif isinstance(device, str): try: - device_type, device_id = device.split(':') + device_type, device_id = device.split(":") except ValueError: device_type = device device_id = 0 return mxnet.Context(device_type.replace("cuda", "gpu"), int(device_id)) else: - raise ValueError(f'Unknown device type {type(device)}') + raise ValueError(f"Unknown device type {type(device)}") -@PluginRegistrar.register(types=None if not HAVE_MXNET else mxnet.nd.NDArray.__mro__[:-1]) +@PluginRegistrar.register( + types=None if not HAVE_MXNET else mxnet.nd.NDArray.__mro__[:-1] +) class MXNetTensor(Plugin): def __init__(self, load_mxnet_device=None, map_mxnet_devices=None, **kwargs): """ @@ -89,7 +92,7 @@ def __init__(self, load_mxnet_device=None, map_mxnet_devices=None, **kwargs): """ self.map_mxnet_devices = map_mxnet_devices or {} if load_mxnet_device is not None: - self.map_mxnet_devices['default'] = load_mxnet_device + self.map_mxnet_devices["default"] = load_mxnet_device self.map_mxnet_devices = mxnet_device_to_str(self.map_mxnet_devices) def encode(self, obj, *args, **kwargs): @@ -106,9 +109,11 @@ def encode(self, obj, *args, **kwargs): """ with io.BytesIO() as memfile: np.save(memfile, obj.asnumpy()) - obj_data = base64.b64encode(memfile.getvalue()).decode('ascii') + obj_data = base64.b64encode(memfile.getvalue()).decode("ascii") obj_device = mxnet_device_to_str(obj.context) - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device)) + return True, dict( + __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device) + ) def decode(self, obj_type, obj_full, *args, **kwargs): """ @@ -122,8 +127,10 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - bool: Always True, indicating that the decoding was successful - mxnet.nd.NDArray: The decoded MXNet tensor data """ - with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile: - obj_device = self.map_mxnet_devices.get(obj_full[2], self.map_mxnet_devices.get('default', None)) + with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile: + obj_device = self.map_mxnet_devices.get( + obj_full[2], self.map_mxnet_devices.get("default", None) + ) if obj_device is not None: obj_device = mxnet_str_to_device(obj_device) return True, mxnet.nd.array(np.load(memfile), ctx=obj_device) diff --git a/wrapyfi/plugins/paddle_tensor.py b/wrapyfi/plugins/paddle_tensor.py index 2a1113e..c66dcc3 100644 --- a/wrapyfi/plugins/paddle_tensor.py +++ b/wrapyfi/plugins/paddle_tensor.py @@ -29,6 +29,7 @@ try: import paddle from paddle import Tensor + HAVE_PADDLE = True except ImportError: HAVE_PADDLE = False @@ -53,7 +54,7 @@ def paddle_device_to_str(device): elif isinstance(device, str): return device.replace("cuda", "gpu") else: - raise ValueError(f'Unknown device type {type(device)}') + raise ValueError(f"Unknown device type {type(device)}") @lru_cache(maxsize=None) @@ -70,12 +71,12 @@ def paddle_str_to_device(device): return device elif isinstance(device, str): try: - device_type, device_id = device.split(':') + device_type, device_id = device.split(":") except ValueError: device_type = device return paddle.device._convert_to_place(device_type.replace("cuda", "gpu")) else: - raise ValueError(f'Unknown device type {type(device)}') + raise ValueError(f"Unknown device type {type(device)}") @PluginRegistrar.register(types=None if not HAVE_PADDLE else paddle.Tensor.__mro__[:-1]) @@ -89,7 +90,7 @@ def __init__(self, load_paddle_device=None, map_paddle_devices=None, **kwargs): """ self.map_paddle_devices = map_paddle_devices or {} if load_paddle_device is not None: - self.map_paddle_devices['default'] = load_paddle_device + self.map_paddle_devices["default"] = load_paddle_device self.map_paddle_devices = paddle_device_to_str(self.map_paddle_devices) def encode(self, obj, *args, **kwargs): @@ -106,9 +107,11 @@ def encode(self, obj, *args, **kwargs): """ with io.BytesIO() as memfile: paddle.save(obj, memfile) - obj_data = base64.b64encode(memfile.getvalue()).decode('ascii') + obj_data = base64.b64encode(memfile.getvalue()).decode("ascii") obj_device = paddle_device_to_str(obj.place) - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device)) + return True, dict( + __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device) + ) def decode(self, obj_type, obj_full, *args, **kwargs): """ @@ -122,11 +125,12 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - bool: Always True, indicating that the decoding was successful - paddle.Tensor: The decoded PaddlePaddle tensor data """ - with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile: - obj_device = self.map_paddle_devices.get(obj_full[2], self.map_paddle_devices.get('default', None)) + with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile: + obj_device = self.map_paddle_devices.get( + obj_full[2], self.map_paddle_devices.get("default", None) + ) if obj_device is not None: obj_device = paddle_str_to_device(obj_device) return True, paddle.Tensor(paddle.load(memfile), place=obj_device) else: return True, paddle.load(memfile) - diff --git a/wrapyfi/plugins/pandas_data.py b/wrapyfi/plugins/pandas_data.py index 6a0d005..c0641a6 100644 --- a/wrapyfi/plugins/pandas_data.py +++ b/wrapyfi/plugins/pandas_data.py @@ -26,12 +26,19 @@ try: import pandas + HAVE_PANDAS = True except ImportError: HAVE_PANDAS = False -@PluginRegistrar.register(types=None if not HAVE_PANDAS else pandas.DataFrame.__mro__[:-1] + pandas.Series.__mro__[:-1]) +@PluginRegistrar.register( + types=( + None + if not HAVE_PANDAS + else pandas.DataFrame.__mro__[:-1] + pandas.Series.__mro__[:-1] + ) +) class PandasData(Plugin): def __init__(self, **kwargs): """ @@ -53,13 +60,15 @@ def encode(self, obj, *args, **kwargs): """ with io.BytesIO() as memfile: if isinstance(obj, pandas.DataFrame): - obj_type = 'DataFrame' + obj_type = "DataFrame" obj.to_json(memfile, orient="split") elif isinstance(obj, pandas.Series): - obj_type = 'Series' + obj_type = "Series" obj.to_frame().to_json(memfile, orient="split") - obj_data = base64.b64encode(memfile.getvalue()).decode('ascii') - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type)) + obj_data = base64.b64encode(memfile.getvalue()).decode("ascii") + return True, dict( + __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type) + ) def decode(self, obj_type, obj_full, *args, **kwargs): """ @@ -73,10 +82,9 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - bool: Always True, indicating that the decoding was successful - Union[pandas.DataFrame, pandas.Series]: The decoded pandas data """ - with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile: + with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile: obj = pandas.read_json(memfile, orient="split") obj_type = obj_full[2] - if obj_type == 'Series': + if obj_type == "Series": obj = obj.iloc[:, 0] return True, obj - diff --git a/wrapyfi/plugins/pillow_image.py b/wrapyfi/plugins/pillow_image.py index 42f56fd..f4f0447 100644 --- a/wrapyfi/plugins/pillow_image.py +++ b/wrapyfi/plugins/pillow_image.py @@ -24,6 +24,7 @@ try: from PIL import Image + HAVE_PIL = True except ImportError: HAVE_PIL = False @@ -50,12 +51,14 @@ def encode(self, obj, *args, **kwargs): - '__wrapyfi__': A tuple containing the class name and encoded data string, with optional image size and mode for raw data """ if obj.format is None: - obj_data = base64.b64encode(obj.tobytes()).decode('ascii') - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj.size, obj.mode)) + obj_data = base64.b64encode(obj.tobytes()).decode("ascii") + return True, dict( + __wrapyfi__=(str(self.__class__.__name__), obj_data, obj.size, obj.mode) + ) else: with io.BytesIO() as memfile: obj.save(memfile, format=obj.format) - obj_data = memfile.getvalue().decode('latin1') + obj_data = memfile.getvalue().decode("latin1") return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data)) def decode(self, obj_type, obj_full, *args, **kwargs): @@ -71,10 +74,11 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - Image.Image: The decoded PIL Image data """ if len(obj_full) == 4: - with io.BytesIO(obj_full[1].encode('ascii')) as memfile: - return True, Image.frombytes(obj_full[3], obj_full[2], memfile.getvalue(), "raw") + with io.BytesIO(obj_full[1].encode("ascii")) as memfile: + return True, Image.frombytes( + obj_full[3], obj_full[2], memfile.getvalue(), "raw" + ) else: - with io.BytesIO(obj_full[1].encode('latin1')) as memfile: + with io.BytesIO(obj_full[1].encode("latin1")) as memfile: memfile.seek(0) return True, Image.open(memfile).copy() - diff --git a/wrapyfi/plugins/pint_quantities.py b/wrapyfi/plugins/pint_quantities.py index 6a9ab9e..ad64cc6 100644 --- a/wrapyfi/plugins/pint_quantities.py +++ b/wrapyfi/plugins/pint_quantities.py @@ -24,6 +24,7 @@ try: from pint import Quantity + HAVE_PINT = True except ImportError: HAVE_PINT = False @@ -50,13 +51,14 @@ def encode(self, obj, *args, **kwargs): - '__wrapyfi__': A tuple containing the class name, encoded data string, and object type """ if isinstance(obj, Quantity): - obj_type = 'Quantity' - obj_data = json.dumps({ - 'magnitude': obj.magnitude, - 'units': str(obj.units) - }).encode('ascii') - obj_data = base64.b64encode(obj_data).decode('ascii') - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type)) + obj_type = "Quantity" + obj_data = json.dumps( + {"magnitude": obj.magnitude, "units": str(obj.units)} + ).encode("ascii") + obj_data = base64.b64encode(obj_data).decode("ascii") + return True, dict( + __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type) + ) else: # TypeError("Unknown object type: {}".format(obj_type)) return False, {} @@ -73,13 +75,16 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - bool: Indicating that the decoding was successful - pint.Quantity: The decoded Pint Quantity data """ - obj_data = base64.b64decode(obj_full[1].encode('ascii')).decode('ascii') + obj_data = base64.b64decode(obj_full[1].encode("ascii")).decode("ascii") obj_data = json.loads(obj_data) obj_type = obj_full[2] - if obj_type == 'Quantity': + if obj_type == "Quantity": from pint import UnitRegistry + ureg = UnitRegistry() - obj = Quantity(obj_data['magnitude'], ureg.parse_expression(obj_data['units'])) + obj = Quantity( + obj_data["magnitude"], ureg.parse_expression(obj_data["units"]) + ) return True, obj else: # TypeError("Unknown object type: {}".format(obj_type)) diff --git a/wrapyfi/plugins/pyarrow_array.py b/wrapyfi/plugins/pyarrow_array.py index 3fd29be..24f9a27 100644 --- a/wrapyfi/plugins/pyarrow_array.py +++ b/wrapyfi/plugins/pyarrow_array.py @@ -24,12 +24,15 @@ try: import pyarrow as pa + HAVE_PYARROW = True except ImportError: HAVE_PYARROW = False -@PluginRegistrar.register(types=None if not HAVE_PYARROW else pa.StructArray.__mro__[:-1]) +@PluginRegistrar.register( + types=None if not HAVE_PYARROW else pa.StructArray.__mro__[:-1] +) class PyArrowArray(Plugin): def __init__(self, **kwargs): """ @@ -51,8 +54,16 @@ def encode(self, obj, *args, **kwargs): """ buffers = [] obj_data = pickle.dumps(obj, protocol=5, buffer_callback=buffers.append) - obj_buffers = list(map(lambda x: base64.b64encode(memoryview(x)).decode('ascii'), buffers)) - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data.decode('latin1'), *obj_buffers)) + obj_buffers = list( + map(lambda x: base64.b64encode(memoryview(x)).decode("ascii"), buffers) + ) + return True, dict( + __wrapyfi__=( + str(self.__class__.__name__), + obj_data.decode("latin1"), + *obj_buffers, + ) + ) def decode(self, obj_type, obj_full, *args, **kwargs): """ @@ -66,8 +77,10 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - bool: Always True, indicating that the decoding was successful - pa.StructArray: The decoded PyArrow StructArray data """ - obj_data = obj_full[1].encode('latin1') - obj_buffers = list(map(lambda x: base64.b64decode(x.encode('ascii')), obj_full[2:])) + obj_data = obj_full[1].encode("latin1") + obj_buffers = list( + map(lambda x: base64.b64decode(x.encode("ascii")), obj_full[2:]) + ) obj_data = bytearray(obj_data) for buf in obj_buffers: obj_data += buf diff --git a/wrapyfi/plugins/pytorch_tensor.py b/wrapyfi/plugins/pytorch_tensor.py index 75f8e18..533ca10 100644 --- a/wrapyfi/plugins/pytorch_tensor.py +++ b/wrapyfi/plugins/pytorch_tensor.py @@ -24,6 +24,7 @@ try: import torch + HAVE_TORCH = True except ImportError: HAVE_TORCH = False @@ -40,7 +41,7 @@ def __init__(self, load_torch_device=None, map_torch_devices=None, **kwargs): """ self.map_torch_devices = map_torch_devices or {} if load_torch_device is not None: - self.map_torch_devices['default'] = load_torch_device + self.map_torch_devices["default"] = load_torch_device def encode(self, obj, *args, **kwargs): """ @@ -56,9 +57,11 @@ def encode(self, obj, *args, **kwargs): """ with io.BytesIO() as memfile: torch.save(obj, memfile) - obj_data = base64.b64encode(memfile.getvalue()).decode('ascii') + obj_data = base64.b64encode(memfile.getvalue()).decode("ascii") obj_device = str(obj.device) - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device)) + return True, dict( + __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device) + ) def decode(self, obj_type, obj_full, *args, **kwargs): """ @@ -72,10 +75,11 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - bool: Always True, indicating that the decoding was successful - torch.Tensor: The decoded PyTorch tensor data """ - with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile: - obj_device = self.map_torch_devices.get(obj_full[2], self.map_torch_devices.get('default', None)) + with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile: + obj_device = self.map_torch_devices.get( + obj_full[2], self.map_torch_devices.get("default", None) + ) if obj_device is not None: return True, torch.load(memfile, map_location=obj_device) else: return True, torch.load(memfile) - diff --git a/wrapyfi/plugins/tensorflow_tensor.py b/wrapyfi/plugins/tensorflow_tensor.py index af2ae70..52bbba7 100644 --- a/wrapyfi/plugins/tensorflow_tensor.py +++ b/wrapyfi/plugins/tensorflow_tensor.py @@ -26,12 +26,15 @@ try: import tensorflow + HAVE_TENSORFLOW = True except ImportError: HAVE_TENSORFLOW = False -@PluginRegistrar.register(types=None if not HAVE_TENSORFLOW else tensorflow.Tensor.__mro__[:-1]) +@PluginRegistrar.register( + types=None if not HAVE_TENSORFLOW else tensorflow.Tensor.__mro__[:-1] +) class TensorflowTensor(Plugin): def __init__(self, **kwargs): """ @@ -53,7 +56,7 @@ def encode(self, obj, *args, **kwargs): """ with io.BytesIO() as memfile: np.save(memfile, obj.numpy()) - obj_data = base64.b64encode(memfile.getvalue()).decode('ascii') + obj_data = base64.b64encode(memfile.getvalue()).decode("ascii") return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data)) def decode(self, obj_type, obj_full, *args, **kwargs): @@ -68,5 +71,5 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - bool: Always True, indicating that the decoding was successful - tensorflow.Tensor: The decoded TensorFlow tensor data """ - with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile: + with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile: return True, tensorflow.convert_to_tensor(np.load(memfile)) diff --git a/wrapyfi/plugins/xarray_data.py b/wrapyfi/plugins/xarray_data.py index f06df4d..24986b7 100644 --- a/wrapyfi/plugins/xarray_data.py +++ b/wrapyfi/plugins/xarray_data.py @@ -28,12 +28,17 @@ try: import xarray as xr + HAVE_XARRAY = True except ImportError: HAVE_XARRAY = False -@PluginRegistrar.register(types=None if not HAVE_XARRAY else xr.DataArray.__mro__[:-1] + xr.Dataset.__mro__[:-1]) +@PluginRegistrar.register( + types=( + None if not HAVE_XARRAY else xr.DataArray.__mro__[:-1] + xr.Dataset.__mro__[:-1] + ) +) class XArrayData(Plugin): def __init__(self, **kwargs): """ @@ -53,11 +58,11 @@ def encode(self, obj, *args, **kwargs): - dict: A dictionary containing: - '__wrapyfi__': A tuple containing the class name, encoded data string, data type, and object name """ - obj_type = 'Dataset' + obj_type = "Dataset" obj_name = None if isinstance(obj, xr.DataArray): obj_dict = obj.to_dataset().to_dict() - obj_type = 'DataArray' + obj_type = "DataArray" obj_name = obj.name elif isinstance(obj, xr.Dataset): obj_dict = obj.to_dict() @@ -68,17 +73,19 @@ def traverse_and_convert(obj): elif isinstance(obj, dict): return {key: traverse_and_convert(value) for key, value in obj.items()} elif isinstance(obj, np.datetime64): - return {'data': str(obj), 'dtype': str(obj.dtype)} + return {"data": str(obj), "dtype": str(obj.dtype)} elif isinstance(obj, datetime): - return {'data': obj.isoformat(), 'dtype': 'datetime'} + return {"data": obj.isoformat(), "dtype": "datetime"} else: return obj converted_obj_dict = traverse_and_convert(obj_dict) obj_json = json.dumps(converted_obj_dict) - obj_data = base64.b64encode(obj_json.encode('ascii')).decode('ascii') + obj_data = base64.b64encode(obj_json.encode("ascii")).decode("ascii") - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type, obj_name)) + return True, dict( + __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type, obj_name) + ) def decode(self, obj_type, obj_full, *args, **kwargs): """ @@ -92,7 +99,7 @@ def decode(self, obj_type, obj_full, *args, **kwargs): - bool: Always True, indicating that the decoding was successful - Union[xr.DataArray, xr.Dataset]: The decoded XArray Data """ - obj_data = base64.b64decode(obj_full[1].encode('ascii')).decode('ascii') + obj_data = base64.b64decode(obj_full[1].encode("ascii")).decode("ascii") xarray_type = obj_full[2] xarray_name = obj_full[3] @@ -100,21 +107,23 @@ def traverse_and_reconvert(obj): if isinstance(obj, list): return [traverse_and_reconvert(item) for item in obj] elif isinstance(obj, dict): - if 'data' in obj and 'dtype' in obj: - if obj['dtype'] == 'datetime': - return datetime.fromisoformat(obj['data']) + if "data" in obj and "dtype" in obj: + if obj["dtype"] == "datetime": + return datetime.fromisoformat(obj["data"]) else: - return np.datetime64(obj['data'], obj['dtype']) - return {key: traverse_and_reconvert(value) for key, value in obj.items()} + return np.datetime64(obj["data"], obj["dtype"]) + return { + key: traverse_and_reconvert(value) for key, value in obj.items() + } else: return obj obj_dict = json.loads(obj_data) reconverted_obj_dict = traverse_and_reconvert(obj_dict) - if xarray_type == 'DataArray': + if xarray_type == "DataArray": dataset = xr.Dataset.from_dict(reconverted_obj_dict) data_array = dataset[xarray_name] return True, data_array else: - return True, xr.Dataset.from_dict(reconverted_obj_dict) \ No newline at end of file + return True, xr.Dataset.from_dict(reconverted_obj_dict) diff --git a/wrapyfi/plugins/zarr_array.py b/wrapyfi/plugins/zarr_array.py index 47ded18..1fb16ff 100644 --- a/wrapyfi/plugins/zarr_array.py +++ b/wrapyfi/plugins/zarr_array.py @@ -26,12 +26,15 @@ try: import zipfile import zarr + HAVE_ZARR = True except ImportError: HAVE_ZARR = False -@PluginRegistrar.register(types=None if not HAVE_ZARR else zarr.Array.__mro__[:-1] + zarr.Group.__mro__[:-1]) +@PluginRegistrar.register( + types=None if not HAVE_ZARR else zarr.Array.__mro__[:-1] + zarr.Group.__mro__[:-1] +) class ZarrData(Plugin): def __init__(self, **kwargs): """ @@ -51,27 +54,32 @@ def encode(self, obj, *args, **kwargs): - dict: A dictionary containing: - '__wrapyfi__': A tuple containing the class name, encoded data string, data type, and object name. """ - obj_type = 'Group' if isinstance(obj, zarr.Group) else 'Array' + obj_type = "Group" if isinstance(obj, zarr.Group) else "Array" obj_name = obj.name if isinstance(obj, zarr.Array) else None with tempfile.TemporaryDirectory() as tmpdirname: - store_path = os.path.join(tmpdirname, 'zarr_store') + store_path = os.path.join(tmpdirname, "zarr_store") - if obj_type == 'Array': + if obj_type == "Array": zarr.save_array(store_path, obj) else: zarr.save_group(store_path, obj) with io.BytesIO() as binary_stream: - with zipfile.ZipFile(binary_stream, 'w') as zipf: + with zipfile.ZipFile(binary_stream, "w") as zipf: for foldername, subfolders, filenames in os.walk(store_path): for filename in filenames: - zipf.write(os.path.join(foldername, filename), - arcname=os.path.relpath(os.path.join(foldername, filename), - store_path)) - obj_data = base64.b64encode(binary_stream.getvalue()).decode('ascii') - - return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type, obj_name)) + zipf.write( + os.path.join(foldername, filename), + arcname=os.path.relpath( + os.path.join(foldername, filename), store_path + ), + ) + obj_data = base64.b64encode(binary_stream.getvalue()).decode("ascii") + + return True, dict( + __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type, obj_name) + ) def decode(self, obj_type, obj_full, *args, **kwargs): """ @@ -87,17 +95,19 @@ def decode(self, obj_type, obj_full, *args, **kwargs): """ obj_data = obj_full[1] zarr_type = obj_full[2] - zarr_name = obj_full[3] # zarr_name is used only if zarr_type is 'Array'. Currently not used. + zarr_name = obj_full[ + 3 + ] # zarr_name is used only if zarr_type is 'Array'. Currently not used. - with io.BytesIO(base64.b64decode(obj_data.encode('ascii'))) as binary_stream: + with io.BytesIO(base64.b64decode(obj_data.encode("ascii"))) as binary_stream: with tempfile.TemporaryDirectory() as tmpdirname: - store_path = os.path.join(tmpdirname, 'zarr_store') - with zipfile.ZipFile(binary_stream, 'r') as zipf: + store_path = os.path.join(tmpdirname, "zarr_store") + with zipfile.ZipFile(binary_stream, "r") as zipf: zipf.extractall(path=store_path) - if zarr_type == 'Array': - array = zarr.open_array(store_path, mode='r') + if zarr_type == "Array": + array = zarr.open_array(store_path, mode="r") return True, array else: - group = zarr.open_group(store_path, mode='r') + group = zarr.open_group(store_path, mode="r") return True, group diff --git a/wrapyfi/publishers/__init__.py b/wrapyfi/publishers/__init__.py index 6880b6e..947d5b0 100755 --- a/wrapyfi/publishers/__init__.py +++ b/wrapyfi/publishers/__init__.py @@ -6,11 +6,22 @@ @Publishers.register("MMO", "fallback") class FallbackPublisher(Publisher): - def __init__(self, name: str, out_topic: str, carrier: str = "", - should_wait: bool = True, missing_middleware_object: str = "", **kwargs): - logging.warning(f"Fallback publisher employed due to missing middleware or object type: " - f"{missing_middleware_object}") - Publisher.__init__(self, name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs) + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "", + should_wait: bool = True, + missing_middleware_object: str = "", + **kwargs, + ): + logging.warning( + f"Fallback publisher employed due to missing middleware or object type: " + f"{missing_middleware_object}" + ) + Publisher.__init__( + self, name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) self.missing_middleware_object = missing_middleware_object def establish(self, repeats: int = -1, **kwargs): @@ -20,4 +31,4 @@ def publish(self, obj): return obj def close(self): - return None \ No newline at end of file + return None diff --git a/wrapyfi/publishers/ros.py b/wrapyfi/publishers/ros.py index f4570d5..ec6537c 100755 --- a/wrapyfi/publishers/ros.py +++ b/wrapyfi/publishers/ros.py @@ -25,8 +25,16 @@ class ROSPublisher(Publisher): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True, - queue_size: int = QUEUE_SIZE, ros_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + ros_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the publisher. @@ -39,14 +47,20 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: :param kwargs: dict: Additional kwargs for the publisher """ if carrier or carrier != "tcp": - logging.warning("[ROS] ROS does not support other carriers than TCP for PUB/SUB pattern. Using TCP.") + logging.warning( + "[ROS] ROS does not support other carriers than TCP for PUB/SUB pattern. Using TCP." + ) carrier = "tcp" - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs) + super().__init__( + name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) ROSMiddleware.activate(**ros_kwargs or {}) self.queue_size = queue_size - def await_connection(self, publisher, out_topic: Optional[str] = None, repeats: Optional[int] = None): + def await_connection( + self, publisher, out_topic: Optional[str] = None, repeats: Optional[int] = None + ): """ Wait for at least one subscriber to connect to the publisher. @@ -88,8 +102,16 @@ def __del__(self): @Publishers.register("NativeObject", "ros") class ROSNativeObjectPublisher(ROSPublisher): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True, - queue_size: int = QUEUE_SIZE, serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObject publisher using the ROS String message assuming a combination of python native objects. @@ -101,7 +123,14 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: :param queue_size: int: Queue size for the publisher. Default is 5 :param serializer_kwargs: dict: Additional kwargs for the serializer """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + should_wait=should_wait, + queue_size=queue_size, + **kwargs, + ) self._plugin_encoder = JsonEncoder self._plugin_kwargs = kwargs self._serializer_kwargs = serializer_kwargs or {} @@ -118,7 +147,9 @@ def establish(self, repeats: Optional[int] = None, **kwargs): :param repeats: int: Number of repeats to await connection. None for infinite. Default is None :return: bool: True if connection established, False otherwise """ - self._publisher = rospy.Publisher(self.out_topic, std_msgs.msg.String, queue_size=self.queue_size) + self._publisher = rospy.Publisher( + self.out_topic, std_msgs.msg.String, queue_size=self.queue_size + ) established = self.await_connection(self._publisher, repeats=repeats) return self.check_establishment(established) @@ -134,16 +165,32 @@ def publish(self, obj): return else: time.sleep(0.2) - obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + obj_str = json.dumps( + obj, + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) self._publisher.publish(obj_str) @Publishers.register("Image", "ros") class ROSImagePublisher(ROSPublisher): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE, - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + **kwargs, + ): """ The ImagePublisher using the ROS Image message assuming a numpy array as input. @@ -158,7 +205,14 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait :param fp: bool: True if the image is floating point, False if it is integer. Default is False :param jpg: bool: True if the image should be compressed as JPG. Default is False """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + should_wait=should_wait, + queue_size=queue_size, + **kwargs, + ) self.width = width self.height = height self.rgb = rgb @@ -166,13 +220,13 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait self.jpg = jpg if self.fp: - self._encoding = '32FC3' if self.rgb else '32FC1' + self._encoding = "32FC3" if self.rgb else "32FC1" self._type = np.float32 else: - self._encoding = 'bgr8' if self.rgb else 'mono8' + self._encoding = "bgr8" if self.rgb else "mono8" self._type = np.uint8 if self.jpg: - self._encoding = 'jpeg' + self._encoding = "jpeg" self._type = np.uint8 self._publisher = None @@ -188,9 +242,15 @@ def establish(self, repeats: Optional[int] = None, **kwargs): :return: bool: True if connection established, False otherwise """ if self.jpg: - self._publisher = rospy.Publisher(self.out_topic, sensor_msgs.msg.CompressedImage, queue_size=self.queue_size) + self._publisher = rospy.Publisher( + self.out_topic, + sensor_msgs.msg.CompressedImage, + queue_size=self.queue_size, + ) else: - self._publisher = rospy.Publisher(self.out_topic, sensor_msgs.msg.Image, queue_size=self.queue_size) + self._publisher = rospy.Publisher( + self.out_topic, sensor_msgs.msg.Image, queue_size=self.queue_size + ) established = self.await_connection(self._publisher) return self.check_establishment(established) @@ -210,23 +270,31 @@ def publish(self, img: np.ndarray): else: time.sleep(0.2) - if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \ - not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)): + if ( + 0 < self.width != img.shape[1] + or 0 < self.height != img.shape[0] + or not ( + (img.ndim == 2 and not self.rgb) + or (img.ndim == 3 and self.rgb and img.shape[2] == 3) + ) + ): raise ValueError("Incorrect image shape for publisher") - img = np.require(img, dtype=self._type, requirements='C') + img = np.require(img, dtype=self._type, requirements="C") if self.jpg: img_msg = sensor_msgs.msg.CompressedImage() img_msg.header.stamp = rospy.Time.now() img_msg.format = "jpeg" - img_msg.data = np.array(cv2.imencode('.jpg', img)[1]).tobytes() + img_msg.data = np.array(cv2.imencode(".jpg", img)[1]).tobytes() else: img_msg = sensor_msgs.msg.Image() img_msg.header.stamp = rospy.Time.now() img_msg.height = img.shape[0] img_msg.width = img.shape[1] img_msg.encoding = self._encoding - img_msg.is_bigendian = img.dtype.byteorder == '>' or (img.dtype.byteorder == '=' and sys.byteorder == 'big') + img_msg.is_bigendian = img.dtype.byteorder == ">" or ( + img.dtype.byteorder == "=" and sys.byteorder == "big" + ) img_msg.step = img.strides[0] img_msg.data = img.tobytes() self._publisher.publish(img_msg) @@ -235,8 +303,18 @@ def publish(self, img: np.ndarray): @Publishers.register("AudioChunk", "ros") class ROSAudioChunkPublisher(ROSPublisher): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE, - channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + **kwargs, + ): """ The AudioChunkPublisher using the ROS Audio message assuming a numpy array as input. @@ -249,7 +327,14 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: :param rate: int: Sampling rate. Default is 44100 :param chunk: int: Chunk size. Default is -1 meaning that the chunk size is not fixed """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + should_wait=should_wait, + queue_size=queue_size, + **kwargs, + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -270,12 +355,18 @@ def establish(self, repeats: Optional[int] = None, **kwargs): from wrapyfi_ros_interfaces.msg import ROSAudioMessage except ImportError: import wrapyfi - logging.error("[ROS] Could not import ROSAudioMessage. " - "Make sure the ROS messages in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros_interfaces_lnk.html") + + logging.error( + "[ROS] Could not import ROSAudioMessage. " + "Make sure the ROS messages in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros_interfaces_lnk.html" + ) sys.exit(1) - self._publisher = rospy.Publisher(self.out_topic, ROSAudioMessage, queue_size=self.queue_size) + self._publisher = rospy.Publisher( + self.out_topic, ROSAudioMessage, queue_size=self.queue_size + ) self._sound_msg = ROSAudioMessage() established = self.await_connection(self._publisher) return self.check_establishment(established) @@ -303,15 +394,17 @@ def publish(self, aud: Tuple[np.ndarray, int]): self.channels = channels if self.channels == -1 else self.channels if 0 < self.chunk != chunk or 0 < self.channels != channels: raise ValueError("Incorrect audio shape for publisher") - aud = np.require(aud, dtype=np.float32, requirements='C') + aud = np.require(aud, dtype=np.float32, requirements="C") aud_msg = self._sound_msg aud_msg.header.stamp = rospy.Time.now() aud_msg.chunk_size = chunk aud_msg.channels = channels aud_msg.sample_rate = rate - aud_msg.is_bigendian = aud.dtype.byteorder == '>' or (aud.dtype.byteorder == '=' and sys.byteorder == 'big') - aud_msg.encoding = 'S16BE' if aud_msg.is_bigendian else 'S16LE' + aud_msg.is_bigendian = aud.dtype.byteorder == ">" or ( + aud.dtype.byteorder == "=" and sys.byteorder == "big" + ) + aud_msg.encoding = "S16BE" if aud_msg.is_bigendian else "S16LE" aud_msg.step = aud.strides[0] aud_msg.data = aud.tobytes() # (aud * 32767.0).tobytes() self._publisher.publish(aud_msg) @@ -326,7 +419,15 @@ class ROSPropertiesPublisher(ROSPublisher): but care should be taken when using dictionaries, since they are analogous with node namespaces: http://wiki.ros.org/rospy/Overview/Parameter%20Server """ - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", persistent: bool = True, **kwargs): + + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + persistent: bool = True, + **kwargs, + ): """ The PropertiesPublisher using the ROS parameter server. @@ -371,8 +472,15 @@ def __del__(self): @Publishers.register("ROSMessage", "ros") class ROSMessagePublisher(ROSPublisher): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", - should_wait: bool = True, queue_size: int = QUEUE_SIZE, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + **kwargs, + ): """ The ROSMessagePublisher using the ROS message type inferred from the message type. Supports standard ROS msgs. @@ -382,7 +490,14 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", :param should_wait: bool: Whether to wait for at least one listener before unblocking the script. Default is True :param queue_size: int: Queue size for the publisher. Default is 5 """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + should_wait=should_wait, + queue_size=queue_size, + **kwargs, + ) self._publisher = None @@ -402,7 +517,9 @@ def establish(self, repeats: Optional[int] = None, obj=None, **kwargs): obj_type = obj._type.split("/") import_msg = importlib.import_module(f"{obj_type[0]}.msg") msg_type = getattr(import_msg, obj_type[1]) - self._publisher = rospy.Publisher(self.out_topic, msg_type, queue_size=self.queue_size) + self._publisher = rospy.Publisher( + self.out_topic, msg_type, queue_size=self.queue_size + ) established = self.await_connection(self._publisher, repeats=repeats) return self.check_establishment(established) diff --git a/wrapyfi/publishers/ros2.py b/wrapyfi/publishers/ros2.py index 40c6ae9..b4441b8 100755 --- a/wrapyfi/publishers/ros2.py +++ b/wrapyfi/publishers/ros2.py @@ -24,8 +24,15 @@ class ROS2Publisher(Publisher, Node): - def __init__(self, name: str, out_topic: str, should_wait: bool = True, - queue_size: int = QUEUE_SIZE, ros2_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + ros2_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the publisher. @@ -38,16 +45,22 @@ def __init__(self, name: str, out_topic: str, should_wait: bool = True, """ carrier = "tcp" if "carrier" in kwargs and kwargs["carrier"] not in ["", None]: - logging.warning("[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP.") + logging.warning( + "[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP." + ) if "carrier" in kwargs: del kwargs["carrier"] ROS2Middleware.activate(**ros2_kwargs or {}) - Publisher.__init__(self, name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs) + Publisher.__init__( + self, name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) Node.__init__(self, name + str(hex(id(self)))) self.queue_size = queue_size - def await_connection(self, publisher, out_topic: Optional[str] = None, repeats: Optional[int] = None): + def await_connection( + self, publisher, out_topic: Optional[str] = None, repeats: Optional[int] = None + ): """ Wait for at least one subscriber to connect to the publisher. @@ -80,7 +93,7 @@ def close(self): """ if hasattr(self, "_publisher") and self._publisher: if self._publisher is not None: - self.destroy_node() + self.destroy_node() def __del__(self): self.close() @@ -89,8 +102,15 @@ def __del__(self): @Publishers.register("NativeObject", "ros2") class ROS2NativeObjectPublisher(ROS2Publisher): - def __init__(self, name, out_topic: str, should_wait: bool = True, - queue_size: int = QUEUE_SIZE, serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name, + out_topic: str, + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObject publisher using the ROS 2 String message assuming a combination of python native objects and numpy arrays as input. Serializes the data (including plugins) using the encoder and sends it as a string. @@ -101,7 +121,9 @@ def __init__(self, name, out_topic: str, should_wait: bool = True, :param queue_size: int: Queue size for the publisher. Default is 5 :param serializer_kwargs: dict: Additional kwargs for the serializer """ - super().__init__(name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs + ) self._plugin_encoder = JsonEncoder self._plugin_kwargs = kwargs self._serializer_kwargs = serializer_kwargs or {} @@ -118,7 +140,9 @@ def establish(self, repeats: Optional[int] = None, **kwargs): :param repeats: int: Number of repeats to await connection. None for infinite. Default is None :return: bool: True if connection established, False otherwise """ - self._publisher = self.create_publisher(std_msgs.msg.String, self.out_topic, qos_profile=self.queue_size) + self._publisher = self.create_publisher( + std_msgs.msg.String, self.out_topic, qos_profile=self.queue_size + ) established = self.await_connection(self._publisher, repeats=repeats) return self.check_establishment(established) @@ -134,8 +158,12 @@ def publish(self, obj): return else: time.sleep(0.2) - obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + obj_str = json.dumps( + obj, + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) obj_str_msg = std_msgs.msg.String() obj_str_msg.data = obj_str self._publisher.publish(obj_str_msg) @@ -144,8 +172,19 @@ def publish(self, obj): @Publishers.register("Image", "ros2") class ROS2ImagePublisher(ROS2Publisher): - def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE, - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + **kwargs, + ): """ The ImagePublisher using the ROS 2 Image message assuming a numpy array as input. @@ -159,7 +198,9 @@ def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_si :param fp: bool: True if the image is floating point, False if it is integer. Default is False :param jpg: bool: True if the image should be compressed as JPG. Default is False """ - super().__init__(name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs + ) self.width = width self.height = height self.rgb = rgb @@ -167,13 +208,13 @@ def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_si self.jpg = jpg if self.fp: - self._encoding = '32FC3' if self.rgb else '32FC1' + self._encoding = "32FC3" if self.rgb else "32FC1" self._type = np.float32 else: - self._encoding = 'bgr8' if self.rgb else 'mono8' + self._encoding = "bgr8" if self.rgb else "mono8" self._type = np.uint8 if self.jpg: - self._encoding = 'jpeg' + self._encoding = "jpeg" self._type = np.uint8 self._publisher = None @@ -189,9 +230,15 @@ def establish(self, repeats: Optional[int] = None, **kwargs): :return: bool: True if connection established, False otherwise """ if self.jpg: - self._publisher = self.create_publisher(sensor_msgs.msg.CompressedImage, self.out_topic, qos_profile=self.queue_size) + self._publisher = self.create_publisher( + sensor_msgs.msg.CompressedImage, + self.out_topic, + qos_profile=self.queue_size, + ) else: - self._publisher = self.create_publisher(sensor_msgs.msg.Image, self.out_topic, qos_profile=self.queue_size) + self._publisher = self.create_publisher( + sensor_msgs.msg.Image, self.out_topic, qos_profile=self.queue_size + ) established = self.await_connection(self._publisher) return self.check_establishment(established) @@ -211,23 +258,31 @@ def publish(self, img: np.ndarray): else: time.sleep(0.2) - if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \ - not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)): + if ( + 0 < self.width != img.shape[1] + or 0 < self.height != img.shape[0] + or not ( + (img.ndim == 2 and not self.rgb) + or (img.ndim == 3 and self.rgb and img.shape[2] == 3) + ) + ): raise ValueError("Incorrect image shape for publisher") - img = np.require(img, dtype=self._type, requirements='C') + img = np.require(img, dtype=self._type, requirements="C") if self.jpg: img_msg = sensor_msgs.msg.CompressedImage() img_msg.header.stamp = rclpy.clock.Clock().now().to_msg() img_msg.format = "jpeg" - img_msg.data = np.array(cv2.imencode('.jpg', img)[1]).tobytes() + img_msg.data = np.array(cv2.imencode(".jpg", img)[1]).tobytes() else: img_msg = sensor_msgs.msg.Image() img_msg.header.stamp = self.get_clock().now().to_msg() img_msg.height = img.shape[0] img_msg.width = img.shape[1] img_msg.encoding = self._encoding - img_msg.is_bigendian = img.dtype.byteorder == '>' or (img.dtype.byteorder == '=' and sys.byteorder == 'big') + img_msg.is_bigendian = img.dtype.byteorder == ">" or ( + img.dtype.byteorder == "=" and sys.byteorder == "big" + ) img_msg.step = img.strides[0] img_msg.data = img.tobytes() self._publisher.publish(img_msg) @@ -236,8 +291,17 @@ def publish(self, img: np.ndarray): @Publishers.register("AudioChunk", "ros2") class ROS2AudioChunkPublisher(ROS2Publisher): - def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE, - channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + **kwargs, + ): """ The AudioChunkPublisher using the ROS 2 Audio message assuming a numpy array as input. @@ -249,7 +313,9 @@ def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_si :param rate: int: Sampling rate. Default is 44100 :param chunk: int: Chunk size. Default is -1 meaning that the chunk size is not fixed """ - super().__init__(name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -270,12 +336,18 @@ def establish(self, repeats: Optional[int] = None, **kwargs): from wrapyfi_ros2_interfaces.msg import ROS2AudioMessage except ImportError: import wrapyfi - logging.error("[ROS 2] Could not import ROS2AudioMessage. " - "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros2_interfaces_lnk.html") + + logging.error( + "[ROS 2] Could not import ROS2AudioMessage. " + "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros2_interfaces_lnk.html" + ) sys.exit(1) - self._publisher = self.create_publisher(ROS2AudioMessage, self.out_topic, qos_profile=self.queue_size) + self._publisher = self.create_publisher( + ROS2AudioMessage, self.out_topic, qos_profile=self.queue_size + ) self._sound_msg = ROS2AudioMessage() established = self.await_connection(self._publisher) return self.check_establishment(established) @@ -303,15 +375,17 @@ def publish(self, aud: Tuple[np.ndarray, int]): self.channels = channels if self.channels == -1 else self.channels if 0 < self.chunk != chunk or 0 < self.channels != channels: raise ValueError("Incorrect audio shape for publisher") - aud = np.require(aud, dtype=np.float32, requirements='C') + aud = np.require(aud, dtype=np.float32, requirements="C") aud_msg = self._sound_msg aud_msg.header.stamp = self.get_clock().now().to_msg() aud_msg.chunk_size = chunk aud_msg.channels = channels aud_msg.sample_rate = rate - aud_msg.is_bigendian = aud.dtype.byteorder == '>' or (aud.dtype.byteorder == '=' and sys.byteorder == 'big') - aud_msg.encoding = 'S16BE' if aud_msg.is_bigendian else 'S16LE' + aud_msg.is_bigendian = aud.dtype.byteorder == ">" or ( + aud.dtype.byteorder == "=" and sys.byteorder == "big" + ) + aud_msg.encoding = "S16BE" if aud_msg.is_bigendian else "S16LE" aud_msg.step = aud.strides[0] aud_msg.data = aud.tobytes() # (aud * 32767.0).tobytes() self._publisher.publish(aud_msg) @@ -327,7 +401,14 @@ def __init__(self, name, out_topic, **kwargs): @Publishers.register("ROS2Message", "ros2") class ROS2MessagePublisher(ROS2Publisher): - def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + should_wait: bool = True, + queue_size: int = QUEUE_SIZE, + **kwargs, + ): """ The ROS2MessagePublisher using the ROS 2 message type determined dynamically. @@ -336,7 +417,9 @@ def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_si :param should_wait: bool: Whether to wait for at least one listener before unblocking the script. Default is True :param queue_size: int: Queue size for the publisher. Default is 5 """ - super().__init__(name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs) + super().__init__( + name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs + ) self._publisher = None self._msg_type = None @@ -362,7 +445,9 @@ def establish(self, msg, repeats: Optional[int] = None, **kwargs): """ self._msg_type = self.get_message_type(msg) - self._publisher = self.create_publisher(self._msg_type, self.out_topic, qos_profile=self.queue_size) + self._publisher = self.create_publisher( + self._msg_type, self.out_topic, qos_profile=self.queue_size + ) established = self.await_connection(self._publisher) return self.check_establishment(established) diff --git a/wrapyfi/publishers/yarp.py b/wrapyfi/publishers/yarp.py index 6a8d893..1ac954d 100755 --- a/wrapyfi/publishers/yarp.py +++ b/wrapyfi/publishers/yarp.py @@ -19,8 +19,17 @@ class YarpPublisher(Publisher): - def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True, - persistent: bool = True, out_topic_connect: Optional[str] = None, yarp_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + should_wait: bool = True, + persistent: bool = True, + out_topic_connect: Optional[str] = None, + yarp_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the publisher. @@ -34,16 +43,22 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc :param yarp_kwargs: dict: Additional kwargs for the Yarp middleware :param kwargs: dict: Additional kwargs for the publisher """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs) + super().__init__( + name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) YarpMiddleware.activate(**yarp_kwargs or {}) self.style = yarp.ContactStyle() self.style.persistent = persistent self.style.carrier = self.carrier - self.out_topic_connect = out_topic + ":out" if out_topic_connect is None else out_topic_connect + self.out_topic_connect = ( + out_topic + ":out" if out_topic_connect is None else out_topic_connect + ) - def await_connection(self, port, out_topic: Optional[str] = None, repeats: Optional[int] = None): + def await_connection( + self, port, out_topic: Optional[str] = None, repeats: Optional[int] = None + ): """ Wait for at least one subscriber to connect to the publisher. @@ -85,8 +100,17 @@ def __del__(self): @Publishers.register("NativeObject", "yarp") class YarpNativeObjectPublisher(YarpPublisher): - def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True, - persistent: bool = True, out_topic_connect: str = None, serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + should_wait: bool = True, + persistent: bool = True, + out_topic_connect: str = None, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObject publisher using the BufferedPortBottle string construct assuming a combination of python native objects and numpy arrays as input. Serializes the data (including plugins) using the encoder and sends it as a string. @@ -100,8 +124,15 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc None appends ':out' to the out_topic. Default is None :param serializer_kwargs: dict: Additional kwargs for the serializer """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, persistent=persistent, - out_topic_connect=out_topic_connect, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + should_wait=should_wait, + persistent=persistent, + out_topic_connect=out_topic_connect, + **kwargs, + ) self._plugin_encoder = JsonEncoder self._plugin_kwargs = kwargs self._serializer_kwargs = serializer_kwargs or {} @@ -121,9 +152,13 @@ def establish(self, repeats: Optional[int] = None, **kwargs): self._port = yarp.BufferedPortBottle() self._port.open(self.out_topic) if self.style.persistent: - self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.style) + self._netconnect = yarp.Network.connect( + self.out_topic, self.out_topic_connect, self.style + ) else: - self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.carrier) + self._netconnect = yarp.Network.connect( + self.out_topic, self.out_topic_connect, self.carrier + ) established = self.await_connection(self._port, repeats=repeats) return self.check_establishment(established) @@ -139,8 +174,12 @@ def publish(self, obj): return else: time.sleep(0.2) - obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + obj_str = json.dumps( + obj, + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) obj_port = self._port.prepare() obj_port.clear() obj_port.addString(obj_str) @@ -150,9 +189,21 @@ def publish(self, obj): @Publishers.register("Image", "yarp") class YarpImagePublisher(YarpPublisher): - def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True, - persistent: bool = True, out_topic_connect: Optional[str] = None, width: int = -1, height: int = -1, - rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + should_wait: bool = True, + persistent: bool = True, + out_topic_connect: Optional[str] = None, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + **kwargs, + ): """ The Image publisher using the BufferedPortImage construct assuming a numpy array as input. @@ -169,8 +220,15 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc :param fp: bool: True if the image is floating point, False if it is integer. Default is False :param jpg: bool: True if the image should be compressed as JPG. Default is False """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, persistent=persistent, - out_topic_connect=out_topic_connect, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + should_wait=should_wait, + persistent=persistent, + out_topic_connect=out_topic_connect, + **kwargs, + ) self.width = width self.height = height self.rgb = rgb @@ -192,15 +250,27 @@ def establish(self, repeats: Optional[int] = None, **kwargs): if self.jpg: self._port = yarp.BufferedPortBottle() elif self.rgb: - self._port = yarp.BufferedPortImageRgbFloat() if self.fp else yarp.BufferedPortImageRgb() + self._port = ( + yarp.BufferedPortImageRgbFloat() + if self.fp + else yarp.BufferedPortImageRgb() + ) else: - self._port = yarp.BufferedPortImageFloat() if self.fp else yarp.BufferedPortImageMono() + self._port = ( + yarp.BufferedPortImageFloat() + if self.fp + else yarp.BufferedPortImageMono() + ) self._type = np.float32 if self.fp else np.uint8 self._port.open(self.out_topic) if self.style.persistent: - self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.style) + self._netconnect = yarp.Network.connect( + self.out_topic, self.out_topic_connect, self.style + ) else: - self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.carrier) + self._netconnect = yarp.Network.connect( + self.out_topic, self.out_topic_connect, self.carrier + ) established = self.await_connection(self._port, repeats=repeats) return self.check_establishment(established) @@ -220,16 +290,22 @@ def publish(self, img: np.ndarray): else: time.sleep(0.2) - if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \ - not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)): + if ( + 0 < self.width != img.shape[1] + or 0 < self.height != img.shape[0] + or not ( + (img.ndim == 2 and not self.rgb) + or (img.ndim == 3 and self.rgb and img.shape[2] == 3) + ) + ): raise ValueError("Incorrect image shape for publisher") - img = np.require(img, dtype=self._type, requirements='C') + img = np.require(img, dtype=self._type, requirements="C") if self.jpg: - img_str = np.array(cv2.imencode('.jpg', img)[1]).tostring() + img_str = np.array(cv2.imencode(".jpg", img)[1]).tostring() with io.BytesIO() as memfile: np.save(memfile, img_str) - img_str = base64.b64encode(memfile.getvalue()).decode('ascii') + img_str = base64.b64encode(memfile.getvalue()).decode("ascii") img_port = self._port.prepare() img_port.clear() img_port.addString(img_str) @@ -245,9 +321,19 @@ def publish(self, img: np.ndarray): @Publishers.register("AudioChunk", "yarp") class YarpAudioChunkPublisher(YarpPublisher): - def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True, - persistent: bool = True, out_topic_connect: Optional[str] = None, - channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + should_wait: bool = True, + persistent: bool = True, + out_topic_connect: Optional[str] = None, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + **kwargs, + ): """ The AudioChunk publisher using the Sound construct assuming a numpy array as input. @@ -262,8 +348,15 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc :param rate: int: Sampling rate. Default is 44100 :param chunk: int: Chunk size. Default is -1 meaning that the chunk size is not fixed """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, out_topic_connect=out_topic_connect, - persistent=persistent, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + should_wait=should_wait, + out_topic_connect=out_topic_connect, + persistent=persistent, + **kwargs, + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -283,7 +376,9 @@ def establish(self, repeats: Optional[int] = None, **kwargs): # create a dummy sound object for transmitting the sound props. This could be cleaner but left for future impl. self._port = yarp.Port() self._port.open(self.out_topic) - self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.carrier) + self._netconnect = yarp.Network.connect( + self.out_topic, self.out_topic_connect, self.carrier + ) self._sound_msg = yarp.Sound() self._sound_msg.setFrequency(self.rate) self._sound_msg.resize(self.chunk, self.channels) @@ -316,10 +411,12 @@ def publish(self, aud: Tuple[np.ndarray, int]): self.channels = channels if self.channels == -1 else self.channels if 0 < self.chunk != chunk or 0 < self.channels != channels: raise ValueError("Incorrect audio shape for publisher") - aud = np.require(aud, dtype=np.float32, requirements='C') + aud = np.require(aud, dtype=np.float32, requirements="C") for i in range(aud.size): - self._sound_msg.set(int(aud.data[i] * 32767), i) # Convert float samples to 16-bit int + self._sound_msg.set( + int(aud.data[i] * 32767), i + ) # Convert float samples to 16-bit int self._port.write(self._sound_msg) diff --git a/wrapyfi/publishers/zeromq.py b/wrapyfi/publishers/zeromq.py index 63109d6..d531b16 100644 --- a/wrapyfi/publishers/zeromq.py +++ b/wrapyfi/publishers/zeromq.py @@ -22,20 +22,38 @@ PARAM_SUB_PORT = int(os.environ.get("WRAPYFI_ZEROMQ_PARAM_SUB_PORT", 5656)) PARAM_REQREP_PORT = int(os.environ.get("WRAPYFI_ZEROMQ_PARAM_REQREP_PORT", 5659)) PARAM_POLL_INTERVAL = int(os.environ.get("WRAPYFI_ZEROMQ_PARAM_POLL_INTERVAL", 1)) -START_PROXY_BROKER = os.environ.get("WRAPYFI_ZEROMQ_START_PROXY_BROKER", True) != "False" +START_PROXY_BROKER = ( + os.environ.get("WRAPYFI_ZEROMQ_START_PROXY_BROKER", True) != "False" +) PROXY_BROKER_SPAWN = os.environ.get("WRAPYFI_ZEROMQ_PROXY_BROKER_SPAWN", "process") -ZEROMQ_PUBSUB_MONITOR_TOPIC = os.environ.get("WRAPYFI_ZEROMQ_PUBSUB_MONITOR_TOPIC", "ZEROMQ/CONNECTIONS") -ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN = os.environ.get("WRAPYFI_ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN", "process") +ZEROMQ_PUBSUB_MONITOR_TOPIC = os.environ.get( + "WRAPYFI_ZEROMQ_PUBSUB_MONITOR_TOPIC", "ZEROMQ/CONNECTIONS" +) +ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN = os.environ.get( + "WRAPYFI_ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN", "process" +) WATCHDOG_POLL_REPEAT = None class ZeroMQPublisher(Publisher): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True, - socket_ip: str = SOCKET_IP, socket_pub_port: int = SOCKET_PUB_PORT, socket_sub_port: int = SOCKET_SUB_PORT, - start_proxy_broker: bool = START_PROXY_BROKER, proxy_broker_spawn: str = PROXY_BROKER_SPAWN, - pubsub_monitor_topic: str = ZEROMQ_PUBSUB_MONITOR_TOPIC, - pubsub_monitor_listener_spawn: Optional[str] = ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN, - zeromq_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + socket_ip: str = SOCKET_IP, + socket_pub_port: int = SOCKET_PUB_PORT, + socket_sub_port: int = SOCKET_SUB_PORT, + start_proxy_broker: bool = START_PROXY_BROKER, + proxy_broker_spawn: str = PROXY_BROKER_SPAWN, + pubsub_monitor_topic: str = ZEROMQ_PUBSUB_MONITOR_TOPIC, + pubsub_monitor_listener_spawn: Optional[ + str + ] = ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN, + zeromq_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the publisher and start the proxy broker if necessary. @@ -54,24 +72,32 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: :param kwargs: Additional kwargs for the publisher """ if carrier or carrier != "tcp": - logging.warning("[ZeroMQ] ZeroMQ does not support other carriers than TCP for PUB/SUB pattern. Using TCP.") + logging.warning( + "[ZeroMQ] ZeroMQ does not support other carriers than TCP for PUB/SUB pattern. Using TCP." + ) carrier = "tcp" - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs) + super().__init__( + name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) self.socket_pub_address = f"{carrier}://{socket_ip}:{socket_pub_port}" self.socket_sub_address = f"{carrier}://{socket_ip}:{socket_sub_port}" - ZeroMQMiddlewarePubSub.activate(socket_pub_address=self.socket_pub_address, - socket_sub_address=self.socket_sub_address, - start_proxy_broker=start_proxy_broker, - proxy_broker_spawn=proxy_broker_spawn, - pubsub_monitor_topic=pubsub_monitor_topic, - pubsub_monitor_listener_spawn=pubsub_monitor_listener_spawn, - **zeromq_kwargs or {}) + ZeroMQMiddlewarePubSub.activate( + socket_pub_address=self.socket_pub_address, + socket_sub_address=self.socket_sub_address, + start_proxy_broker=start_proxy_broker, + proxy_broker_spawn=proxy_broker_spawn, + pubsub_monitor_topic=pubsub_monitor_topic, + pubsub_monitor_listener_spawn=pubsub_monitor_listener_spawn, + **zeromq_kwargs or {}, + ) ZeroMQMiddlewarePubSub().shared_monitor_data.add_topic(self.out_topic) - def await_connection(self, out_topic: Optional[str] = None, repeats: Optional[int] = None): + def await_connection( + self, out_topic: Optional[str] = None, repeats: Optional[int] = None + ): """ Wait for the connection to be established. @@ -90,7 +116,9 @@ def await_connection(self, out_topic: Optional[str] = None, repeats: Optional[in return True while repeats > 0 or repeats <= -1: repeats -= 1 - connected = ZeroMQMiddlewarePubSub().shared_monitor_data.is_connected(out_topic) + connected = ZeroMQMiddlewarePubSub().shared_monitor_data.is_connected( + out_topic + ) if connected: break time.sleep(0.02) @@ -115,8 +143,15 @@ def __del__(self): @Publishers.register("NativeObject", "zeromq") class ZeroMQNativeObjectPublisher(ZeroMQPublisher): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True, - serializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + serializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ The NativeObjectPublisher using the ZeroMQ message construct assuming a combination of python native objects and numpy arrays as input. Serializes the data (including plugins) using the encoder and sends it as a string. @@ -128,7 +163,9 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: :param serializer_kwargs: dict: Additional kwargs for the serializer :param kwargs: dict: Additional kwargs for the publisher """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs) + super().__init__( + name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) self._socket = self._netconnect = None self._plugin_encoder = JsonEncoder @@ -148,9 +185,13 @@ def establish(self, repeats: Optional[int] = None, **kwargs): self._socket = zmq.Context.instance().socket(zmq.PUB) for socket_property in ZeroMQMiddlewarePubSub().zeromq_kwargs.items(): if isinstance(socket_property[1], str): - self._socket.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1]) + self._socket.setsockopt_string( + getattr(zmq, socket_property[0]), socket_property[1] + ) else: - self._socket.setsockopt(getattr(zmq, socket_property[0]), socket_property[1]) + self._socket.setsockopt( + getattr(zmq, socket_property[0]), socket_property[1] + ) self._socket.connect(self.socket_sub_address) self._topic = self.out_topic.encode() established = self.await_connection(repeats=repeats) @@ -168,16 +209,31 @@ def publish(self, obj): return else: time.sleep(0.2) - obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + obj_str = json.dumps( + obj, + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) self._socket.send_multipart([self._topic, obj_str.encode()]) @Publishers.register("Image", "zeromq") class ZeroMQImagePublisher(ZeroMQNativeObjectPublisher): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True, - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + **kwargs, + ): """ The ImagePublisher using the ZeroMQ message construct assuming a numpy array as input. @@ -191,7 +247,9 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: :param fp: bool: True if the image is floating point, False if it is integer. Default is False :param jpg: bool: True if the image should be compressed as JPG. Default is False """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs) + super().__init__( + name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs + ) self.width = width self.height = height self.rgb = rgb @@ -215,25 +273,44 @@ def publish(self, img: np.ndarray): return else: time.sleep(0.2) - if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \ - not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)): + if ( + 0 < self.width != img.shape[1] + or 0 < self.height != img.shape[0] + or not ( + (img.ndim == 2 and not self.rgb) + or (img.ndim == 3 and self.rgb and img.shape[2] == 3) + ) + ): raise ValueError("Incorrect image shape for publisher") - if not img.flags['C_CONTIGUOUS']: + if not img.flags["C_CONTIGUOUS"]: img = np.ascontiguousarray(img) if self.jpg: - img_str = np.array(cv2.imencode('.jpg', img)[1]).tostring() + img_str = np.array(cv2.imencode(".jpg", img)[1]).tostring() else: - img_str = json.dumps(img, cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs).encode() - img_header = '{timestamp:' + str(time.time()) + '}' + img_str = json.dumps( + img, + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ).encode() + img_header = "{timestamp:" + str(time.time()) + "}" self._socket.send_multipart([self._topic, img_header.encode(), img_str]) @Publishers.register("AudioChunk", "zeromq") class ZeroMQAudioChunkPublisher(ZeroMQNativeObjectPublisher): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True, - channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + should_wait: bool = True, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + **kwargs, + ): """ The AudioChunkPublisher using the ZeroMQ message construct assuming a numpy array as input. @@ -245,8 +322,18 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: :param rate: int: Sampling rate. Default is 44100 :param chunk: int: Chunk size. Default is -1 meaning that the chunk size is not fixed """ - super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, - width=chunk, height=channels, rgb=False, fp=True, jpg=False, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + should_wait=should_wait, + width=chunk, + height=channels, + rgb=False, + fp=True, + jpg=False, + **kwargs, + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -274,11 +361,15 @@ def publish(self, aud: Tuple[np.ndarray, int]): self.channels = channels if self.channels == -1 else self.channels if 0 < self.chunk != chunk or 0 < self.channels != channels: raise ValueError("Incorrect audio shape for publisher") - aud = np.require(aud, dtype=np.float32, requirements='C') - - aud_str = json.dumps((chunk, channels, rate, aud), cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs).encode() - aud_header = '{timestamp:' + str(time.time()) + '}' + aud = np.require(aud, dtype=np.float32, requirements="C") + + aud_str = json.dumps( + (chunk, channels, rate, aud), + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ).encode() + aud_header = "{timestamp:" + str(time.time()) + "}" self._socket.send_multipart([self._topic, aud_header.encode(), aud_str]) diff --git a/wrapyfi/servers/__init__.py b/wrapyfi/servers/__init__.py index ef2329e..9ef2390 100755 --- a/wrapyfi/servers/__init__.py +++ b/wrapyfi/servers/__init__.py @@ -6,9 +6,18 @@ @Servers.register("MMO", "fallback") class FallbackServer(Server): - def __init__(self, name: str, out_topic: str, carrier: str = "", missing_middleware_object: str = "", **kwargs): - logging.warning(f"Fallback server employed due to missing middleware or object type: " - f"{missing_middleware_object}") + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "", + missing_middleware_object: str = "", + **kwargs, + ): + logging.warning( + f"Fallback server employed due to missing middleware or object type: " + f"{missing_middleware_object}" + ) Server.__init__(self, name, out_topic, carrier=carrier, **kwargs) self.missing_middleware_object = missing_middleware_object diff --git a/wrapyfi/servers/ros.py b/wrapyfi/servers/ros.py index c5c1fda..b22753b 100755 --- a/wrapyfi/servers/ros.py +++ b/wrapyfi/servers/ros.py @@ -13,13 +13,24 @@ import sensor_msgs.msg from wrapyfi.connect.servers import Server, Servers -from wrapyfi.middlewares.ros import ROSMiddleware, ROSNativeObjectService, ROSImageService +from wrapyfi.middlewares.ros import ( + ROSMiddleware, + ROSNativeObjectService, + ROSImageService, +) from wrapyfi.encoders import JsonEncoder, JsonDecodeHook class ROSServer(Server): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", ros_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + ros_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the server. @@ -30,7 +41,9 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", ros_kwargs: :param kwargs: dict: Additional kwargs for the server """ if carrier or carrier != "tcp": - logging.warning("[ROS] ROS does not support other carriers than TCP for REQ/REP pattern. Using TCP.") + logging.warning( + "[ROS] ROS does not support other carriers than TCP for REQ/REP pattern. Using TCP." + ) carrier = "tcp" super().__init__(name, out_topic, carrier=carrier, **kwargs) ROSMiddleware.activate(**ros_kwargs or {}) @@ -52,8 +65,15 @@ class ROSNativeObjectServer(ROSServer): SEND_QUEUE = queue.Queue(maxsize=1) RECEIVE_QUEUE = queue.Queue(maxsize=1) - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", - serializer_kwargs: Optional[dict] = None, deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + serializer_kwargs: Optional[dict] = None, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling native Python objects, serializing them to JSON strings for transmission. @@ -77,7 +97,9 @@ def establish(self): """ Establish the connection to the server. """ - self._server = rospy.Service(self.out_topic, ROSNativeObjectService, self._service_callback) + self._server = rospy.Service( + self.out_topic, ROSNativeObjectService, self._service_callback + ) self.established = True def await_request(self, *args, **kwargs): @@ -94,7 +116,11 @@ def await_request(self, *args, **kwargs): self.establish() try: request = ROSNativeObjectServer.RECEIVE_QUEUE.get(block=True) - [args, kwargs] = json.loads(request.data, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + [args, kwargs] = json.loads( + request.data, + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) return args, kwargs except rospy.ServiceException as e: logging.error("[ROS] Service call failed: %s" % e) @@ -118,14 +144,20 @@ def reply(self, obj): :param obj: Any: The Python object to be serialized and sent """ try: - obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + obj_str = json.dumps( + obj, + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) obj_msg = std_msgs.msg.String() obj_msg.data = obj_str ROSNativeObjectServer.SEND_QUEUE.put(obj_msg, block=False) except queue.Full: - logging.warning(f"[ROS] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) @Servers.register("Image", "ros") @@ -133,9 +165,18 @@ class ROSImageServer(ROSServer): SEND_QUEUE = queue.Queue(maxsize=1) RECEIVE_QUEUE = queue.Queue(maxsize=1) - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling native Python objects, serializing them to JSON strings for transmission. @@ -150,7 +191,9 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", """ super().__init__(name, out_topic, carrier=carrier, **kwargs) if "jpg" in kwargs: - logging.warning("[ROS] ROS currently does not support JPG encoding in REQ/REP. Using raw image.") + logging.warning( + "[ROS] ROS currently does not support JPG encoding in REQ/REP. Using raw image." + ) kwargs.pop("jpg") self.width = width self.height = height @@ -158,10 +201,10 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", self.fp = fp if self.fp: - self._encoding = '32FC3' if self.rgb else '32FC1' + self._encoding = "32FC3" if self.rgb else "32FC1" self._type = np.float32 else: - self._encoding = 'bgr8' if self.rgb else 'mono8' + self._encoding = "bgr8" if self.rgb else "mono8" self._type = np.uint8 self._plugin_kwargs = kwargs @@ -174,7 +217,9 @@ def establish(self): """ Establish the connection to the server. """ - self._server = rospy.Service(self.out_topic, ROSImageService, self._service_callback) + self._server = rospy.Service( + self.out_topic, ROSImageService, self._service_callback + ) self.established = True def await_request(self, *args, **kwargs): @@ -191,7 +236,11 @@ def await_request(self, *args, **kwargs): self.establish() try: request = ROSImageServer.RECEIVE_QUEUE.get(block=True) - [args, kwargs] = json.loads(request.data, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + [args, kwargs] = json.loads( + request.data, + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) return args, kwargs except rospy.ServiceException as e: logging.error("[ROS] Service call failed: %s" % e) @@ -215,22 +264,32 @@ def reply(self, img: np.ndarray): :param img: np.ndarray: Image to publish """ try: - if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \ - not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)): + if ( + 0 < self.width != img.shape[1] + or 0 < self.height != img.shape[0] + or not ( + (img.ndim == 2 and not self.rgb) + or (img.ndim == 3 and self.rgb and img.shape[2] == 3) + ) + ): raise ValueError("Incorrect image shape for publisher") - img = np.require(img, dtype=self._type, requirements='C') + img = np.require(img, dtype=self._type, requirements="C") img_msg = sensor_msgs.msg.Image() img_msg.header.stamp = rospy.Time.now() img_msg.height = img.shape[0] img_msg.width = img.shape[1] img_msg.encoding = self._encoding - img_msg.is_bigendian = img.dtype.byteorder == '>' or (img.dtype.byteorder == '=' and sys.byteorder == 'big') + img_msg.is_bigendian = img.dtype.byteorder == ">" or ( + img.dtype.byteorder == "=" and sys.byteorder == "big" + ) img_msg.step = img.strides[0] img_msg.data = img.tobytes() ROSImageServer.SEND_QUEUE.put(img_msg, block=False) except queue.Full: - logging.warning(f"[ROS] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) @Servers.register("AudioChunk", "ros") @@ -238,9 +297,17 @@ class ROSAudioChunkServer(ROSServer): SEND_QUEUE = queue.Queue(maxsize=1) RECEIVE_QUEUE = queue.Queue(maxsize=1) - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", - channels: int = 1, rate: int = 44100, chunk: int = -1, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling audio data as numpy arrays. @@ -271,12 +338,18 @@ def establish(self): from wrapyfi_ros_interfaces.srv import ROSAudioService except ImportError: import wrapyfi - logging.error("[ROS] Could not import ROSAudioService. " - "Make sure the ROS services in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros_interfaces_lnk.html") + + logging.error( + "[ROS] Could not import ROSAudioService. " + "Make sure the ROS services in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros_interfaces_lnk.html" + ) sys.exit(1) - self._server = rospy.Service(self.out_topic, ROSAudioService, self._service_callback) + self._server = rospy.Service( + self.out_topic, ROSAudioService, self._service_callback + ) self._rep_msg = ROSAudioService._response_class().response self.established = True @@ -294,8 +367,11 @@ def await_request(self, *args, **kwargs): self.establish() try: request = ROSAudioChunkServer.RECEIVE_QUEUE.get(block=True) - [args, kwargs] = json.loads(request.request, object_hook=self._plugin_decoder_hook, - **self._deserializer_kwargs) + [args, kwargs] = json.loads( + request.request, + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) return args, kwargs except rospy.ServiceException as e: logging.error("[ROS] Service call failed: %s" % e) @@ -329,18 +405,22 @@ def reply(self, aud: Tuple[np.ndarray, int]): self.channels = channels if self.channels == -1 else self.channels if 0 < self.chunk != chunk or 0 < self.channels != channels: raise ValueError("Incorrect audio shape for publisher") - aud = np.require(aud, dtype=np.float32, requirements='C') + aud = np.require(aud, dtype=np.float32, requirements="C") aud_msg = self._rep_msg aud_msg.header.stamp = rospy.Time.now() aud_msg.chunk_size = chunk aud_msg.channels = channels aud_msg.sample_rate = rate - aud_msg.is_bigendian = aud.dtype.byteorder == '>' or (aud.dtype.byteorder == '=' and sys.byteorder == 'big') - aud_msg.encoding = 'S16BE' if aud_msg.is_bigendian else 'S16LE' + aud_msg.is_bigendian = aud.dtype.byteorder == ">" or ( + aud.dtype.byteorder == "=" and sys.byteorder == "big" + ) + aud_msg.encoding = "S16BE" if aud_msg.is_bigendian else "S16LE" aud_msg.step = aud.strides[0] aud_msg.data = aud.tobytes() # (aud * 32767.0).tobytes() ROSAudioChunkServer.SEND_QUEUE.put(aud_msg, block=False) except queue.Full: - logging.warning(f"[ROS] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) diff --git a/wrapyfi/servers/ros2.py b/wrapyfi/servers/ros2.py index 89f16ae..7ab1eda 100755 --- a/wrapyfi/servers/ros2.py +++ b/wrapyfi/servers/ros2.py @@ -22,7 +22,9 @@ class ROS2Server(Server, Node): - def __init__(self, name: str, out_topic: str, ros2_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, name: str, out_topic: str, ros2_kwargs: Optional[dict] = None, **kwargs + ): """ Initialize the server. @@ -34,7 +36,8 @@ def __init__(self, name: str, out_topic: str, ros2_kwargs: Optional[dict] = None carrier = "tcp" if "carrier" in kwargs and kwargs["carrier"] not in ["", None]: logging.warning( - "[ROS 2] ROS 2 currently does not support explicit carrier setting for REQ/REP pattern. Using TCP.") + "[ROS 2] ROS 2 currently does not support explicit carrier setting for REQ/REP pattern. Using TCP." + ) if "carrier" in kwargs: del kwargs["carrier"] @@ -62,8 +65,14 @@ class ROS2NativeObjectServer(ROS2Server): SEND_QUEUE = queue.Queue(maxsize=1) RECEIVE_QUEUE = queue.Queue(maxsize=1) - def __init__(self, name: str, out_topic: str, - serializer_kwargs: Optional[dict] = None, deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + serializer_kwargs: Optional[dict] = None, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling native Python objects, serializing them to JSON strings for transmission. @@ -89,13 +98,19 @@ def establish(self): from wrapyfi_ros2_interfaces.srv import ROS2NativeObjectService except ImportError: import wrapyfi - logging.error("[ROS 2] Could not import ROS2NativeObjectService. " - "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros2_interfaces_lnk.html") + + logging.error( + "[ROS 2] Could not import ROS2NativeObjectService. " + "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros2_interfaces_lnk.html" + ) sys.exit(1) - self._server = self.create_service(ROS2NativeObjectService, self.out_topic, self._service_callback) + self._server = self.create_service( + ROS2NativeObjectService, self.out_topic, self._service_callback + ) self._req_msg = ROS2NativeObjectService.Request() self._rep_msg = ROS2NativeObjectService.Response() @@ -114,13 +129,18 @@ def await_request(self, *args, **kwargs): if not self.established: self.establish() try: - self._background_callback = threading.Thread(name='ros2_server', target=rclpy.spin_once, - args=(self,), kwargs={}) + self._background_callback = threading.Thread( + name="ros2_server", target=rclpy.spin_once, args=(self,), kwargs={} + ) self._background_callback.setDaemon(True) self._background_callback.start() request = ROS2NativeObjectServer.RECEIVE_QUEUE.get(block=True) - [args, kwargs] = json.loads(request.request, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + [args, kwargs] = json.loads( + request.request, + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) return args, kwargs except Exception as e: logging.error("[ROS 2] Service call failed %s" % e) @@ -145,13 +165,19 @@ def reply(self, obj): :param obj: Any: The Python object to be serialized and sent """ try: - obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + obj_str = json.dumps( + obj, + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) self._rep_msg.response = obj_str ROS2NativeObjectServer.SEND_QUEUE.put(self._rep_msg, block=False) except queue.Full: - logging.warning(f"[ROS 2] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS 2] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) @Servers.register("Image", "ros2") @@ -159,9 +185,18 @@ class ROS2ImageServer(ROS2Server): SEND_QUEUE = queue.Queue(maxsize=1) RECEIVE_QUEUE = queue.Queue(maxsize=1) - def __init__(self, name: str, out_topic: str, - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling native Python objects, serializing them to JSON strings for transmission. @@ -187,13 +222,13 @@ def __init__(self, name: str, out_topic: str, self.jpg = jpg if self.fp: - self._encoding = '32FC3' if self.rgb else '32FC1' + self._encoding = "32FC3" if self.rgb else "32FC1" self._type = np.float32 else: - self._encoding = 'bgr8' if self.rgb else 'mono8' + self._encoding = "bgr8" if self.rgb else "mono8" self._type = np.uint8 if self.jpg: - self._encoding = 'jpeg' + self._encoding = "jpeg" self._type = np.uint8 self._server = None @@ -203,20 +238,31 @@ def establish(self): Establish the connection to the server. """ try: - from wrapyfi_ros2_interfaces.srv import ROS2ImageService, ROS2CompressedImageService + from wrapyfi_ros2_interfaces.srv import ( + ROS2ImageService, + ROS2CompressedImageService, + ) except ImportError: import wrapyfi - logging.error("[ROS 2] Could not import ROS2NativeObjectService. " - "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros2_interfaces_lnk.html") + + logging.error( + "[ROS 2] Could not import ROS2NativeObjectService. " + "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros2_interfaces_lnk.html" + ) sys.exit(1) if self.jpg: - self._server = self.create_service(ROS2CompressedImageService, self.out_topic, self._service_callback) + self._server = self.create_service( + ROS2CompressedImageService, self.out_topic, self._service_callback + ) self._req_msg = ROS2CompressedImageService.Request() self._rep_msg = ROS2CompressedImageService.Response() else: - self._server = self.create_service(ROS2ImageService, self.out_topic, self._service_callback) + self._server = self.create_service( + ROS2ImageService, self.out_topic, self._service_callback + ) self._req_msg = ROS2ImageService.Request() self._rep_msg = ROS2ImageService.Response() self.established = True @@ -234,13 +280,18 @@ def await_request(self, *args, **kwargs): if not self.established: self.establish() try: - self._background_callback = threading.Thread(name='ros2_server', target=rclpy.spin_once, - args=(self,), kwargs={}) + self._background_callback = threading.Thread( + name="ros2_server", target=rclpy.spin_once, args=(self,), kwargs={} + ) self._background_callback.setDaemon(True) self._background_callback.start() request = ROS2ImageServer.RECEIVE_QUEUE.get(block=True) - [args, kwargs] = json.loads(request.request, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + [args, kwargs] = json.loads( + request.request, + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) return args, kwargs except Exception as e: logging.error("[ROS 2] Service call failed %s" % e) @@ -265,28 +316,38 @@ def reply(self, img: np.ndarray): :param img: np.ndarray: Image to send formatted as a cv2 image - np.ndarray[img_height, img_width, channels] """ try: - if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \ - not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)): + if ( + 0 < self.width != img.shape[1] + or 0 < self.height != img.shape[0] + or not ( + (img.ndim == 2 and not self.rgb) + or (img.ndim == 3 and self.rgb and img.shape[2] == 3) + ) + ): raise ValueError("Incorrect image shape for publisher") - img = np.require(img, dtype=self._type, requirements='C') + img = np.require(img, dtype=self._type, requirements="C") img_msg = self._rep_msg.response if self.jpg: img_msg.header.stamp = rclpy.clock.Clock().now().to_msg() img_msg.format = "jpeg" - img_msg.data = np.array(cv2.imencode('.jpg', img)[1]).tobytes() + img_msg.data = np.array(cv2.imencode(".jpg", img)[1]).tobytes() else: img_msg.header.stamp = rclpy.clock.Clock().now().to_msg() img_msg.height = img.shape[0] img_msg.width = img.shape[1] img_msg.encoding = self._encoding - img_msg.is_bigendian = img.dtype.byteorder == '>' or (img.dtype.byteorder == '=' and sys.byteorder == 'big') + img_msg.is_bigendian = img.dtype.byteorder == ">" or ( + img.dtype.byteorder == "=" and sys.byteorder == "big" + ) img_msg.step = img.strides[0] img_msg.data = img.tobytes() self._rep_msg.response = img_msg ROS2ImageServer.SEND_QUEUE.put(self._rep_msg, block=False) except queue.Full: - logging.warning(f"[ROS 2] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS 2] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) @Servers.register("AudioChunk", "ros2") @@ -294,9 +355,16 @@ class ROS2AudioChunkServer(ROS2Server): SEND_QUEUE = queue.Queue(maxsize=1) RECEIVE_QUEUE = queue.Queue(maxsize=1) - def __init__(self, name: str, out_topic: str, - channels: int = 1, rate: int = 44100, chunk: int = -1, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling audio data as numpy arrays. @@ -325,12 +393,18 @@ def establish(self): from wrapyfi_ros2_interfaces.srv import ROS2AudioService except ImportError: import wrapyfi - logging.error("[ROS 2] Could not import ROS2AudioService. " - "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " - "Refer to the documentation for more information: \n" + - wrapyfi.__doc__ + "ros2_interfaces_lnk.html") + + logging.error( + "[ROS 2] Could not import ROS2AudioService. " + "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. " + "Refer to the documentation for more information: \n" + + wrapyfi.__doc__ + + "ros2_interfaces_lnk.html" + ) sys.exit(1) - self._server = self.create_service(ROS2AudioService, self.out_topic, self._service_callback) + self._server = self.create_service( + ROS2AudioService, self.out_topic, self._service_callback + ) self._req_msg = ROS2AudioService.Request() self._rep_msg = ROS2AudioService.Response() self.established = True @@ -348,13 +422,18 @@ def await_request(self, *args, **kwargs): if not self.established: self.establish() try: - self._background_callback = threading.Thread(name='ros2_server', target=rclpy.spin_once, - args=(self,), kwargs={}) + self._background_callback = threading.Thread( + name="ros2_server", target=rclpy.spin_once, args=(self,), kwargs={} + ) self._background_callback.setDaemon(True) self._background_callback.start() request = ROS2AudioChunkServer.RECEIVE_QUEUE.get(block=True) - [args, kwargs] = json.loads(request.request, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + [args, kwargs] = json.loads( + request.request, + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) return args, kwargs except Exception as e: logging.error("[ROS 2] Service call failed %s" % e) @@ -389,19 +468,23 @@ def reply(self, aud: Tuple[np.ndarray, int]): self.channels = channels if self.channels == -1 else self.channels if 0 < self.chunk != chunk or 0 < self.channels != channels: raise ValueError("Incorrect audio shape for publisher") - aud = np.require(aud, dtype=np.float32, requirements='C') + aud = np.require(aud, dtype=np.float32, requirements="C") aud_msg = self._rep_msg.response aud_msg.header.stamp = self.get_clock().now().to_msg() aud_msg.chunk_size = chunk aud_msg.channels = channels aud_msg.sample_rate = rate - aud_msg.is_bigendian = aud.dtype.byteorder == '>' or (aud.dtype.byteorder == '=' and sys.byteorder == 'big') - aud_msg.encoding = 'S16BE' if aud_msg.is_bigendian else 'S16LE' + aud_msg.is_bigendian = aud.dtype.byteorder == ">" or ( + aud.dtype.byteorder == "=" and sys.byteorder == "big" + ) + aud_msg.encoding = "S16BE" if aud_msg.is_bigendian else "S16LE" aud_msg.step = aud.strides[0] aud_msg.data = aud.tobytes() # (aud * 32767.0).tobytes() self._rep_msg.response = aud_msg ROS2AudioChunkServer.SEND_QUEUE.put(self._rep_msg, block=False) except queue.Full: - logging.warning(f"[ROS 2] Discarding data because queue is full. " - f"This happened due to bad synchronization in {self.__name__}") + logging.warning( + f"[ROS 2] Discarding data because queue is full. " + f"This happened due to bad synchronization in {self.__name__}" + ) diff --git a/wrapyfi/servers/yarp.py b/wrapyfi/servers/yarp.py index ed67075..cdc0c53 100755 --- a/wrapyfi/servers/yarp.py +++ b/wrapyfi/servers/yarp.py @@ -12,9 +12,16 @@ class YarpServer(Server): - def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", - out_topic_connect: Optional[str] = None, persistent: bool = True, - yarp_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + out_topic_connect: Optional[str] = None, + persistent: bool = True, + yarp_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the server. @@ -26,7 +33,13 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc :param yarp_kwargs: dict: Additional kwargs for the Yarp middleware :param kwargs: dict: Additional kwargs for the server """ - super().__init__(name, out_topic, carrier=carrier, out_topic_connect=out_topic_connect, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + out_topic_connect=out_topic_connect, + **kwargs, + ) YarpMiddleware.activate(**yarp_kwargs or {}) self.style = yarp.ContactStyle() self.style.persistent = persistent @@ -49,9 +62,17 @@ def __del__(self): @Servers.register("NativeObject", "yarp") class YarpNativeObjectServer(YarpServer): - def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", - out_topic_connect: Optional[str] = None, persistent: bool = True, - serializer_kwargs: Optional[dict] = None, deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + out_topic_connect: Optional[str] = None, + persistent: bool = True, + serializer_kwargs: Optional[dict] = None, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling native Python objects, serializing them to JSON strings for transmission. @@ -64,7 +85,14 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc :param serializer_kwargs: dict: Additional kwargs for the serializer :param deserializer_kwargs: dict: Additional kwargs for the deserializer """ - super().__init__(name, out_topic, carrier=carrier, out_topic_connect=out_topic_connect, persistent=persistent, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + out_topic_connect=out_topic_connect, + persistent=persistent, + **kwargs, + ) self._plugin_encoder = JsonEncoder self._plugin_kwargs = kwargs self._serializer_kwargs = serializer_kwargs or {} @@ -80,11 +108,17 @@ def establish(self): self._port = yarp.RpcServer() self._port.open(self.out_topic) if self.style.persistent: - self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.style) + self._netconnect = yarp.Network.connect( + self.out_topic, self.out_topic_connect, self.style + ) else: - self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.carrier) + self._netconnect = yarp.Network.connect( + self.out_topic, self.out_topic_connect, self.carrier + ) - self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.carrier) + self._netconnect = yarp.Network.connect( + self.out_topic, self.out_topic_connect, self.carrier + ) if self.persistent: self.established = True @@ -106,7 +140,11 @@ def await_request(self, *args, **kwargs): request = False while not request: request = self._port.read(obj_msg, True) - [args, kwargs] = json.loads(obj_msg.get(0).asString(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + [args, kwargs] = json.loads( + obj_msg.get(0).asString(), + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) return args, kwargs except Exception as e: logging.error("[YARP] Service call failed: %s" % e) @@ -119,8 +157,12 @@ def reply(self, obj): :param obj: Any: The Python object to be serialized and sent """ - obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs, - serializer_kwrags=self._serializer_kwargs) + obj_str = json.dumps( + obj, + cls=self._plugin_encoder, + **self._plugin_kwargs, + serializer_kwrags=self._serializer_kwargs, + ) obj_msg = yarp.Bottle() obj_msg.clear() obj_msg.addString(obj_str) @@ -132,10 +174,20 @@ def reply(self, obj): @Servers.register("Image", "yarp") class YarpImageServer(YarpNativeObjectServer): - def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", - out_topic_connect: Optional[str] = None, persistent: bool = True, - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + out_topic_connect: Optional[str] = None, + persistent: bool = True, + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling image data as numpy arrays, serializing them to JSON strings for transmission. @@ -152,9 +204,19 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc :param deserializer_kwargs: dict: Additional kwargs for the deserializer """ if "jpg" in kwargs: - logging.warning("[YARP] YARP currently does not support JPG encoding in REQ/REP. Using raw image.") + logging.warning( + "[YARP] YARP currently does not support JPG encoding in REQ/REP. Using raw image." + ) kwargs.pop("jpg") - super().__init__(name, out_topic, carrier=carrier, out_topic_connect=out_topic_connect, persistent=persistent, deserializer_kwargs=deserializer_kwargs, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + out_topic_connect=out_topic_connect, + persistent=persistent, + deserializer_kwargs=deserializer_kwargs, + **kwargs, + ) self.width = width self.height = height self.rgb = rgb @@ -166,8 +228,14 @@ def reply(self, img: np.ndarray): :param img: np.ndarray: Image to send formatted as a cv2 image - np.ndarray[img_height, img_width, channels] """ - if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \ - not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)): + if ( + 0 < self.width != img.shape[1] + or 0 < self.height != img.shape[0] + or not ( + (img.ndim == 2 and not self.rgb) + or (img.ndim == 3 and self.rgb and img.shape[2] == 3) + ) + ): raise ValueError("Incorrect image shape for publisher") # img = np.require(img, dtype=self._type, requirements='C') super().reply(img) @@ -175,10 +243,19 @@ def reply(self, img: np.ndarray): @Servers.register("AudioChunk", "yarp") class YarpAudioChunkServer(YarpNativeObjectServer): - def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", - out_topic_connect: Optional[str] = None, persistent: bool = True, - channels: int = 1, rate: int = 44100, chunk: int = -1, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: Literal["tcp", "udp", "mcast"] = "tcp", + out_topic_connect: Optional[str] = None, + persistent: bool = True, + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling audio data as numpy arrays, serializing them to JSON strings for transmission. @@ -193,7 +270,15 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc :param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio) :param deserializer_kwargs: dict: Additional kwargs for the deserializer """ - super().__init__(name, out_topic, carrier=carrier, out_topic_connect=out_topic_connect, persistent=persistent, deserializer_kwargs=deserializer_kwargs, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + out_topic_connect=out_topic_connect, + persistent=persistent, + deserializer_kwargs=deserializer_kwargs, + **kwargs, + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -214,5 +299,5 @@ def reply(self, aud: Tuple[np.ndarray, int]): self.channels = channels if self.channels == -1 else self.channels if 0 < self.chunk != chunk or 0 < self.channels != channels: raise ValueError("Incorrect audio shape for publisher") - aud = np.require(aud, dtype=np.float32, requirements='C') + aud = np.require(aud, dtype=np.float32, requirements="C") super().reply((chunk, channels, rate, aud)) diff --git a/wrapyfi/servers/zeromq.py b/wrapyfi/servers/zeromq.py index 32271a3..4c21eb8 100644 --- a/wrapyfi/servers/zeromq.py +++ b/wrapyfi/servers/zeromq.py @@ -16,16 +16,27 @@ SOCKET_IP = os.environ.get("WRAPYFI_ZEROMQ_SOCKET_IP", "127.0.0.1") SOCKET_PUB_PORT = int(os.environ.get("WRAPYFI_ZEROMQ_SOCKET_REQ_PORT", 5558)) SOCKET_SUB_PORT = int(os.environ.get("WRAPYFI_ZEROMQ_SOCKET_REP_PORT", 5559)) -START_PROXY_BROKER = os.environ.get("WRAPYFI_ZEROMQ_START_PROXY_BROKER", True) != "False" +START_PROXY_BROKER = ( + os.environ.get("WRAPYFI_ZEROMQ_START_PROXY_BROKER", True) != "False" +) PROXY_BROKER_SPAWN = os.environ.get("WRAPYFI_ZEROMQ_PROXY_BROKER_SPAWN", "process") WATCHDOG_POLL_REPEAT = None class ZeroMQServer(Server): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", - socket_ip: str = SOCKET_IP, socket_rep_port: int = SOCKET_PUB_PORT, socket_req_port: int = SOCKET_SUB_PORT, - start_proxy_broker: bool = START_PROXY_BROKER, proxy_broker_spawn: bool = PROXY_BROKER_SPAWN, - zeromq_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + socket_ip: str = SOCKET_IP, + socket_rep_port: int = SOCKET_PUB_PORT, + socket_req_port: int = SOCKET_SUB_PORT, + start_proxy_broker: bool = START_PROXY_BROKER, + proxy_broker_spawn: bool = PROXY_BROKER_SPAWN, + zeromq_kwargs: Optional[dict] = None, + **kwargs, + ): """ Initialize the server and start the device broker if necessary. @@ -41,20 +52,26 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", :param kwargs: dict: Additional kwargs for the server """ if out_topic != "": - logging.warning(f"[ZeroMQ] ZeroMQ does not support topics for the REQ/REP pattern. Topic {out_topic} removed") + logging.warning( + f"[ZeroMQ] ZeroMQ does not support topics for the REQ/REP pattern. Topic {out_topic} removed" + ) out_topic = "" if carrier or carrier != "tcp": - logging.warning("[ZeroMQ] ZeroMQ does not support other carriers than TCP for REQ/REP pattern. Using TCP.") + logging.warning( + "[ZeroMQ] ZeroMQ does not support other carriers than TCP for REQ/REP pattern. Using TCP." + ) carrier = "tcp" super().__init__(name, out_topic, carrier=carrier, **kwargs) self.socket_rep_address = f"{carrier}://{socket_ip}:{socket_rep_port}" self.socket_req_address = f"{carrier}://{socket_ip}:{socket_req_port}" if start_proxy_broker: - ZeroMQMiddlewareReqRep.activate(socket_rep_address=self.socket_rep_address, - socket_req_address=self.socket_req_address, - proxy_broker_spawn=proxy_broker_spawn, - **zeromq_kwargs or {}) + ZeroMQMiddlewareReqRep.activate( + socket_rep_address=self.socket_rep_address, + socket_req_address=self.socket_req_address, + proxy_broker_spawn=proxy_broker_spawn, + **zeromq_kwargs or {}, + ) else: ZeroMQMiddlewareReqRep.activate(**zeromq_kwargs or {}) @@ -72,8 +89,15 @@ def __del__(self): @Servers.register("NativeObject", "zeromq") class ZeroMQNativeObjectServer(ZeroMQServer): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", - serializer_kwargs: Optional[dict] = None, deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + serializer_kwargs: Optional[dict] = None, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling native Python objects, serializing them to JSON strings for transmission. @@ -98,9 +122,13 @@ def establish(self, **kwargs): self._socket = zmq.Context().instance().socket(zmq.REP) for socket_property in ZeroMQMiddlewareReqRep().zeromq_kwargs.items(): if isinstance(socket_property[1], str): - self._socket.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1]) + self._socket.setsockopt_string( + getattr(zmq, socket_property[0]), socket_property[1] + ) else: - self._socket.setsockopt(getattr(zmq, socket_property[0]), socket_property[1]) + self._socket.setsockopt( + getattr(zmq, socket_property[0]), socket_property[1] + ) self._socket.connect(self.socket_req_address) self.established = True @@ -116,7 +144,11 @@ def await_request(self, *args, **kwargs): """ message = self._socket.recv_string() try: - request = json.loads(message, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) + request = json.loads( + message, + object_hook=self._plugin_decoder_hook, + **self._deserializer_kwargs, + ) args, kwargs = request return args, kwargs except json.JSONDecodeError as e: @@ -137,9 +169,19 @@ def reply(self, obj): @Servers.register("Image", "zeromq") class ZeroMQImageServer(ZeroMQNativeObjectServer): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", - width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + width: int = -1, + height: int = -1, + rgb: bool = True, + fp: bool = False, + jpg: bool = False, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling image data as numpy arrays, serializing them to JSON strings for transmission. @@ -153,7 +195,13 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", :param jpg: bool: True if the image should be decompressed from JPG. Default is False :param deserializer_kwargs: dict: Additional kwargs for the deserializer """ - super().__init__(name, out_topic, carrier=carrier, deserializer_kwargs=deserializer_kwargs, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + deserializer_kwargs=deserializer_kwargs, + **kwargs, + ) self.width = width self.height = height @@ -173,15 +221,21 @@ def reply(self, img: np.ndarray): logging.warning("[ZeroMQ] Image is None. Skipping reply.") return - if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \ - not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)): + if ( + 0 < self.width != img.shape[1] + or 0 < self.height != img.shape[0] + or not ( + (img.ndim == 2 and not self.rgb) + or (img.ndim == 3 and self.rgb and img.shape[2] == 3) + ) + ): raise ValueError("Incorrect image shape for server reply") - if not img.flags['C_CONTIGUOUS']: + if not img.flags["C_CONTIGUOUS"]: img = np.ascontiguousarray(img) if self.jpg: - _, img_encoded = cv2.imencode('.jpg', img) + _, img_encoded = cv2.imencode(".jpg", img) img_bytes = img_encoded.tobytes() self._socket.send(img_bytes) else: @@ -193,9 +247,17 @@ def reply(self, img: np.ndarray): @Servers.register("AudioChunk", "zeromq") class ZeroMQAudioChunkServer(ZeroMQNativeObjectServer): - def __init__(self, name: str, out_topic: str, carrier: str = "tcp", - channels: int = 1, rate: int = 44100, chunk: int = -1, - deserializer_kwargs: Optional[dict] = None, **kwargs): + def __init__( + self, + name: str, + out_topic: str, + carrier: str = "tcp", + channels: int = 1, + rate: int = 44100, + chunk: int = -1, + deserializer_kwargs: Optional[dict] = None, + **kwargs, + ): """ Specific server handling audio data as numpy arrays, serializing them to JSON strings for transmission. @@ -207,7 +269,13 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", :param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio) :param deserializer_kwargs: dict: Additional kwargs for the deserializer """ - super().__init__(name, out_topic, carrier=carrier, deserializer_kwargs=deserializer_kwargs, **kwargs) + super().__init__( + name, + out_topic, + carrier=carrier, + deserializer_kwargs=deserializer_kwargs, + **kwargs, + ) self.channels = channels self.rate = rate self.chunk = chunk @@ -226,8 +294,8 @@ def reply(self, aud: Tuple[np.ndarray, int]): self.channels = channels if self.channels == -1 else self.channels if 0 < self.chunk != chunk or 0 < self.channels != channels: raise ValueError("Incorrect audio shape for publisher") - aud = np.require(aud, dtype=np.float32, requirements='C') + aud = np.require(aud, dtype=np.float32, requirements="C") aud_list = aud.tolist() aud_json = json.dumps({"aud": (int(chunk), int(channels), int(rate), aud_list)}) - self._socket.send_string(aud_json) \ No newline at end of file + self._socket.send_string(aud_json) diff --git a/wrapyfi/standalone/zeromq_param_server.py b/wrapyfi/standalone/zeromq_param_server.py index c0238bc..d50e168 100644 --- a/wrapyfi/standalone/zeromq_param_server.py +++ b/wrapyfi/standalone/zeromq_param_server.py @@ -11,12 +11,12 @@ def access_nested_dict(d, keys): - return reduce(lambda x,y: x[y], keys[:-1], d)[keys[-1]] + return reduce(lambda x, y: x[y], keys[:-1], d)[keys[-1]] def parse_prefix(param_full: str, topics: dict): # split the string by '/' to get the individual components - components = param_full.split('/') + components = param_full.split("/") # If there is only one component in the list, add it to the topics dictionary # as a value and return if len(components) == 1: @@ -45,29 +45,44 @@ def reverse_parse_prefix(topics: dict, prefix: str = ""): # If the value of the key is a string, add the key and value to the list # of strings as a '/' separated string if isinstance(topics[key], str): - strings.append(key + '/' + topics[key]) + strings.append(key + "/" + topics[key]) # If the value of the key is a dictionary, recursively call generate_strings # with the value and add the returned list of strings to the list of strings # with the key as a prefix else: nested_strings = reverse_parse_prefix(topics[key]) for string in nested_strings: - strings.append(key + '/' + string) + strings.append(key + "/" + string) # return the list of strings filtered for repeated /'s - return [prefix + string.replace('//', '/') for string in strings] - - -def main(role, param_ip, param_pub_port, param_sub_port, param_reqrep_port, param_poll_interval, param_poll_repeat, **kwargs): + return [prefix + string.replace("//", "/") for string in strings] + + +def main( + role, + param_ip, + param_pub_port, + param_sub_port, + param_reqrep_port, + param_poll_interval, + param_poll_repeat, + **kwargs, +): if role == "server": param_pub_address = f"tcp://{param_ip}:{param_pub_port}" param_sub_address = f"tcp://{param_ip}:{param_sub_port}" param_reqrep_address = f"tcp://{param_ip}:{param_reqrep_port}" - ZeroMQMiddlewareParamServer.activate(**{"param_pub_address": param_pub_address, - "param_sub_address": param_sub_address, - "param_reqrep_address": param_reqrep_address, - "param_poll_interval": param_poll_interval, - "proxy_broker_spawn": "process", "verbose": True}, **kwargs) + ZeroMQMiddlewareParamServer.activate( + **{ + "param_pub_address": param_pub_address, + "param_sub_address": param_sub_address, + "param_reqrep_address": param_reqrep_address, + "param_poll_interval": param_poll_interval, + "proxy_broker_spawn": "process", + "verbose": True, + }, + **kwargs, + ) while True: pass @@ -87,17 +102,20 @@ def main(role, param_ip, param_pub_port, param_sub_port, param_reqrep_port, para while True: # send a request to the request server starting with write, read, delete, set or get - new_commands = str(input(f"input command: (default - {default_command})")) or default_command + new_commands = ( + str(input(f"input command: (default - {default_command})")) + or default_command + ) default_command = new_commands # pass a list of commands to the request server e.g. ['set /foo/bar/21', 'set /foo/bar/baz/42', 'get /foo/bar', 'delete /foo/bar/baz', 'read /foo'] - if '[' in new_commands: - new_commands = new_commands.replace('[\'', '').replace('\']', '') - new_commands = new_commands.split('\', \'') + if "[" in new_commands: + new_commands = new_commands.replace("['", "").replace("']", "") + new_commands = new_commands.split("', '") # write supports writing full tree dicts e.g. write {'foo': {'bar': {'': '42', 'car': '43'}}} - elif 'write' in new_commands: - new_commands = new_commands.replace('write ', '') - if new_commands.startswith('{'): - new_commands = new_commands.replace('\n', '').replace('\t', '') + elif "write" in new_commands: + new_commands = new_commands.replace("write ", "") + if new_commands.startswith("{"): + new_commands = new_commands.replace("\n", "").replace("\t", "") new_commands = json.loads(new_commands) new_commands = reverse_parse_prefix(new_commands, prefix="set ") # pass commands directly to the request server e.g. set /foo/bar/42 @@ -111,7 +129,7 @@ def main(role, param_ip, param_pub_port, param_sub_port, param_reqrep_port, para print("Received reply from server: %s" % reply) if "success:::" in reply: topics = {} - current_prefix = reply.split(":::")[1].encode('utf-8') + current_prefix = reply.split(":::")[1].encode("utf-8") param_server.subscribe(current_prefix) # time.sleep(0.2) prev_params = Counter() @@ -121,40 +139,91 @@ def main(role, param_ip, param_pub_port, param_sub_port, param_reqrep_port, para try: prefix, param, value = param_server.recv_multipart() except zmq.error.Again: - print("No new parameters received. Need atleast one topic to subscribe to.") + print( + "No new parameters received. Need atleast one topic to subscribe to." + ) break # construct the full parameter name with the namespace prefix - prefix, param, value = prefix.decode('utf-8'), param.decode('utf-8'), value.decode('utf-8') - if (param in prev_params or param is None) and prev_params[param] == param_poll_repeat: + prefix, param, value = ( + prefix.decode("utf-8"), + param.decode("utf-8"), + value.decode("utf-8"), + ) + if (param in prev_params or param is None) and prev_params[ + param + ] == param_poll_repeat: break prev_params[param] += 1 full_param = "/".join([prefix, param, value]) parse_prefix(full_param, topics) # print("Received update: %s" % (full_param)) try: - topic_results = access_nested_dict(topics, current_prefix.decode('utf-8').split('/')) + topic_results = access_nested_dict( + topics, current_prefix.decode("utf-8").split("/") + ) except KeyError: - print(current_prefix.decode('utf-8') + " has no children") + print(current_prefix.decode("utf-8") + " has no children") continue - print(json.dumps({current_prefix.decode('utf-8'): topic_results}, indent=None, default=str)) - print("Reverse parse (always in set mode regardless of the transmitted command):") - print(reverse_parse_prefix(topic_results, prefix=f"set {current_prefix.decode('utf-8')}/")) + print( + json.dumps( + {current_prefix.decode("utf-8"): topic_results}, + indent=None, + default=str, + ) + ) + print( + "Reverse parse (always in set mode regardless of the transmitted command):" + ) + print( + reverse_parse_prefix( + topic_results, + prefix=f"set {current_prefix.decode('utf-8')}/", + ) + ) # close the connection param_server.unsubscribe(current_prefix) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--param_ip", type=str, default="127.0.0.1", help="Parameter sever IP address") - parser.add_argument("--param_pub_port", type=int, default=5655, help="Socket publishing (PUB) port") - parser.add_argument("--param_sub_port", type=int, default=5656, help="Socket subscription (SUB) port") - parser.add_argument("--param_reqrep_port", type=int, default=5659, help="Socket request (REQ)/reply (REP) port") - parser.add_argument("--param_poll_interval", type=int, default=1, help="PUB/SUB poll interval in milliseconds " - "(only used when role=server)") - parser.add_argument("--param_poll_repeat", type=int, default=5, help="SUB poll repetition in terms of no. of " - "attempts (only used when role=client)") - parser.add_argument("--role", type=str, default="server", choices=["server", "client"], - help="Parameter server in serving (server) or requesting (client) mode") + parser.add_argument( + "--param_ip", type=str, default="127.0.0.1", help="Parameter sever IP address" + ) + parser.add_argument( + "--param_pub_port", type=int, default=5655, help="Socket publishing (PUB) port" + ) + parser.add_argument( + "--param_sub_port", + type=int, + default=5656, + help="Socket subscription (SUB) port", + ) + parser.add_argument( + "--param_reqrep_port", + type=int, + default=5659, + help="Socket request (REQ)/reply (REP) port", + ) + parser.add_argument( + "--param_poll_interval", + type=int, + default=1, + help="PUB/SUB poll interval in milliseconds " "(only used when role=server)", + ) + parser.add_argument( + "--param_poll_repeat", + type=int, + default=5, + help="SUB poll repetition in terms of no. of " + "attempts (only used when role=client)", + ) + parser.add_argument( + "--role", + type=str, + default="server", + choices=["server", "client"], + help="Parameter server in serving (server) or requesting (client) mode", + ) return parser.parse_args() diff --git a/wrapyfi/standalone/zeromq_proxy_broker.py b/wrapyfi/standalone/zeromq_proxy_broker.py index ccde90e..88c0afb 100644 --- a/wrapyfi/standalone/zeromq_proxy_broker.py +++ b/wrapyfi/standalone/zeromq_proxy_broker.py @@ -1,5 +1,6 @@ import argparse import logging + logging.getLogger().setLevel(logging.INFO) import zmq @@ -7,22 +8,42 @@ from wrapyfi.middlewares.zeromq import ZeroMQMiddlewareReqRep, ZeroMQMiddlewarePubSub -def main(comm_type, socket_ip, socket_pub_port, socket_sub_port, socket_rep_port, socket_req_port, **kwargs): +def main( + comm_type, + socket_ip, + socket_pub_port, + socket_sub_port, + socket_rep_port, + socket_req_port, + **kwargs, +): if comm_type == "pubsub": socket_pub_address = f"tcp://{socket_ip}:{socket_pub_port}" socket_sub_address = f"tcp://{socket_ip}:{socket_sub_port}" - ZeroMQMiddlewarePubSub.activate(**{"socket_pub_address": socket_pub_address, - "socket_sub_address": socket_sub_address, - "proxy_broker_spawn": "process", "verbose": True}, **kwargs) + ZeroMQMiddlewarePubSub.activate( + **{ + "socket_pub_address": socket_pub_address, + "socket_sub_address": socket_sub_address, + "proxy_broker_spawn": "process", + "verbose": True, + }, + **kwargs, + ) while True: pass elif comm_type == "reqrep": socket_rep_address = f"tcp://{socket_ip}:{socket_rep_port}" socket_req_address = f"tcp://{socket_ip}:{socket_req_port}" - ZeroMQMiddlewareReqRep.activate(**{"socket_rep_address": socket_rep_address, - "socket_req_address": socket_req_address, - "proxy_broker_spawn": "process", "verbose": True}, **kwargs) + ZeroMQMiddlewareReqRep.activate( + **{ + "socket_rep_address": socket_rep_address, + "socket_req_address": socket_req_address, + "proxy_broker_spawn": "process", + "verbose": True, + }, + **kwargs, + ) while True: pass @@ -44,27 +65,41 @@ def main(comm_type, socket_ip, socket_pub_port, socket_sub_port, socket_rep_port event = dict(poller.poll(1000)) if xpub_socket in event: message = xpub_socket.recv_multipart() - #print("[ZeroMQ BROKER] xpub_socket recv message: %r" % message) + # print("[ZeroMQ BROKER] xpub_socket recv message: %r" % message) xsub_socket.send_multipart(message) if xsub_socket in event: message = xsub_socket.recv_multipart() - #print("[ZeroMQ BROKER] xsub_socket recv message: %r" % message) + # print("[ZeroMQ BROKER] xsub_socket recv message: %r" % message) xpub_socket.send_multipart(message) + def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--socket_ip", type=str, default="127.0.0.1", help="Socket IP address") - parser.add_argument("--socket_pub_port", type=int, default=5555, help="Socket publishing port") - parser.add_argument("--socket_sub_port", type=int, default=5556, help="Socket subscription port") - parser.add_argument("--socket_rep_port", type=int, default=5559, help="Socket reply port") - parser.add_argument("--socket_req_port", type=int, default=5560, help="Socket request port") - parser.add_argument("--comm_type", type=str, default="pubsub", choices=["pubsub", "pubsubpoll", "reqrep"], - help="The zeromq communication pattern") + parser.add_argument( + "--socket_ip", type=str, default="127.0.0.1", help="Socket IP address" + ) + parser.add_argument( + "--socket_pub_port", type=int, default=5555, help="Socket publishing port" + ) + parser.add_argument( + "--socket_sub_port", type=int, default=5556, help="Socket subscription port" + ) + parser.add_argument( + "--socket_rep_port", type=int, default=5559, help="Socket reply port" + ) + parser.add_argument( + "--socket_req_port", type=int, default=5560, help="Socket request port" + ) + parser.add_argument( + "--comm_type", + type=str, + default="pubsub", + choices=["pubsub", "pubsubpoll", "reqrep"], + help="The zeromq communication pattern", + ) return parser.parse_args() if __name__ == "__main__": args = parse_args() main(**vars(args)) - - diff --git a/wrapyfi/standalone/zeromq_pubsub_topic_monitor.py b/wrapyfi/standalone/zeromq_pubsub_topic_monitor.py index a20d966..b119f57 100644 --- a/wrapyfi/standalone/zeromq_pubsub_topic_monitor.py +++ b/wrapyfi/standalone/zeromq_pubsub_topic_monitor.py @@ -16,8 +16,8 @@ def monitor_active_connections(socket_pub_address, topic): while True: topic, message = subscriber.recv_multipart() - topic = topic.decode('utf-8') - data = json.loads(message.decode('utf-8')) + topic = topic.decode("utf-8") + data = json.loads(message.decode("utf-8")) logging.info(f"[ZeroMQ] Topic: {topic}, Data: {data}") except Exception as e: @@ -26,15 +26,23 @@ def monitor_active_connections(socket_pub_address, topic): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--socket_pub_address", type=str, default="tcp://127.0.0.1:5555", - help="Socket subscription address") - parser.add_argument("--topic", type=str, default="ZEROMQ/CONNECTIONS", - help="Topic to subscribe. The ZEROMQ/CONNECTIONS topic monitors subscribers to a any topic on " - "the sub_address. Note that zeromq_proxy_broker.py in pubsub mode (either as a standalone" - "or spawned by default by any publisher) must be running on socket_pub_address") + parser.add_argument( + "--socket_pub_address", + type=str, + default="tcp://127.0.0.1:5555", + help="Socket subscription address", + ) + parser.add_argument( + "--topic", + type=str, + default="ZEROMQ/CONNECTIONS", + help="Topic to subscribe. The ZEROMQ/CONNECTIONS topic monitors subscribers to a any topic on " + "the sub_address. Note that zeromq_proxy_broker.py in pubsub mode (either as a standalone" + "or spawned by default by any publisher) must be running on socket_pub_address", + ) return parser.parse_args() if __name__ == "__main__": args = parse_args() - monitor_active_connections(args.socket_pub_address, args.topic) \ No newline at end of file + monitor_active_connections(args.socket_pub_address, args.topic) diff --git a/wrapyfi/tests/test_middleware.py b/wrapyfi/tests/test_middleware.py index 5398f31..7190de4 100644 --- a/wrapyfi/tests/test_middleware.py +++ b/wrapyfi/tests/test_middleware.py @@ -13,6 +13,7 @@ def test_publish_listen(self): receive messages using the PUB/SUB pattern. """ import wrapyfi.tests.tools.class_test as class_test + # importlib.reload(wrapyfi) # importlib.reload(class_test) test_func = class_test.test_func @@ -23,15 +24,29 @@ def test_publish_listen(self): self.skipTest(f"{self.MWARE} not installed") listen_queue = Queue(maxsize=10) - test_lsn = multiprocessing.Process(target=test_func, args=(listen_queue,), - kwargs={"mode": "listen", "mware": self.MWARE, "iterations": 10, - "should_wait": True}) + test_lsn = multiprocessing.Process( + target=test_func, + args=(listen_queue,), + kwargs={ + "mode": "listen", + "mware": self.MWARE, + "iterations": 10, + "should_wait": True, + }, + ) # test_lsn.daemon = True publish_queue = Queue(maxsize=10) - test_pub = multiprocessing.Process(target=test_func, args=(publish_queue,), - kwargs={"mode": "publish", "mware": self.MWARE, "iterations": 10, - "should_wait": True}) + test_pub = multiprocessing.Process( + target=test_func, + args=(publish_queue,), + kwargs={ + "mode": "publish", + "mware": self.MWARE, + "iterations": 10, + "should_wait": True, + }, + ) # test_pub.daemon = True test_lsn.start() @@ -40,7 +55,9 @@ def test_publish_listen(self): test_pub.join() for i in range(10): try: - self.assertDictEqual(listen_queue.get(timeout=3), publish_queue.get(timeout=3)) + self.assertDictEqual( + listen_queue.get(timeout=3), publish_queue.get(timeout=3) + ) except queue.Empty: self.assertEqual(i, 9) @@ -50,6 +67,7 @@ class ROS2TestMiddleware(ZeroMQTestMiddleware): Test the ROS 2 wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class are also run for the ROS 2 wrapper. """ + MWARE = "ros2" @@ -58,6 +76,7 @@ class YarpTestMiddleware(ZeroMQTestMiddleware): Test the YARP wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class are also run for the YARP wrapper. """ + MWARE = "yarp" @@ -66,8 +85,9 @@ class ROSTestMiddleware(ZeroMQTestMiddleware): Test the ROS wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class are also run for the ROS wrapper. """ + MWARE = "ros" -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/wrapyfi/tests/test_wrapper.py b/wrapyfi/tests/test_wrapper.py index dfd04cc..116f918 100644 --- a/wrapyfi/tests/test_wrapper.py +++ b/wrapyfi/tests/test_wrapper.py @@ -12,6 +12,7 @@ def test_activate_communication(self): the connection should no longer be active. The ``activate_communication`` method should also be callable multiple times on the same function. """ import wrapyfi.tests.tools.class_test as class_test + Test = class_test.Test Test.close_all_instances() if self.MWARE not in Test.get_communicators(): @@ -21,7 +22,9 @@ def test_activate_communication(self): test, test2, test3 = Test(), Test(), Test() test.activate_communication(test.exchange_object, mode="publish") test2.activate_communication(test2.exchange_object, mode=None) - test2.activate_communication(test2.exchange_object, mode="publish") # override None + test2.activate_communication( + test2.exchange_object, mode="publish" + ) # override None test3.activate_communication(test3.exchange_object, mode="disable") # ensure both nodes are registered @@ -29,9 +32,11 @@ def test_activate_communication(self): self.assertIn("Test.exchange_object.2", Test._MiddlewareCommunicator__registry) self.assertIn("Test.exchange_object.3", Test._MiddlewareCommunicator__registry) for i in range(2): - msg_object, = test.exchange_object(mware=self.MWARE) - msg_object2, = test2.exchange_object(mware=self.MWARE, topic="/test/test_native_exchange2") - msg_object3, = test3.exchange_object(mware=self.MWARE) + (msg_object,) = test.exchange_object(mware=self.MWARE) + (msg_object2,) = test2.exchange_object( + mware=self.MWARE, topic="/test/test_native_exchange2" + ) + (msg_object3,) = test3.exchange_object(mware=self.MWARE) self.assertDictEqual(msg_object, msg_object2) self.assertIsNone(msg_object3) test.close() @@ -40,10 +45,14 @@ def test_activate_communication(self): # ensure only one node is registered self.assertIn("Test.exchange_object", Test._MiddlewareCommunicator__registry) self.assertIn("Test.exchange_object.2", Test._MiddlewareCommunicator__registry) - self.assertNotIn("Test.exchange_object.3", Test._MiddlewareCommunicator__registry) + self.assertNotIn( + "Test.exchange_object.3", Test._MiddlewareCommunicator__registry + ) for i in range(2): - msg_object2, = test2.exchange_object(mware=self.MWARE, topic="/test/test_native_exchange2") + (msg_object2,) = test2.exchange_object( + mware=self.MWARE, topic="/test/test_native_exchange2" + ) test2.close() test3.close() del test2, test3 @@ -60,6 +69,7 @@ def test_close(self): multiple times on different instances of a class. """ import wrapyfi.tests.tools.class_test as class_test + Test = class_test.Test Test.close_all_instances() @@ -82,11 +92,17 @@ def test_close(self): del test2 self.assertIn("Test.exchange_object", Test._MiddlewareCommunicator__registry) self.assertIn("Test.exchange_object.2", Test._MiddlewareCommunicator__registry) - self.assertNotIn("Test.exchange_object.3", Test._MiddlewareCommunicator__registry) + self.assertNotIn( + "Test.exchange_object.3", Test._MiddlewareCommunicator__registry + ) # check that the first & third node are still registered - test_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"]["__WRAPYFI_INSTANCES"].index(hex(id(test))) - test3_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"]["__WRAPYFI_INSTANCES"].index(hex(id(test3))) + test_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"][ + "__WRAPYFI_INSTANCES" + ].index(hex(id(test))) + test3_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"][ + "__WRAPYFI_INSTANCES" + ].index(hex(id(test3))) self.assertEqual(test_id, 0) self.assertEqual(test3_id, 1) @@ -100,8 +116,12 @@ def test_close(self): self.assertIn("Test.exchange_object", Test._MiddlewareCommunicator__registry) # check that the first & third node are still registered - test_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"]["__WRAPYFI_INSTANCES"].index(hex(id(test))) - test3_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"]["__WRAPYFI_INSTANCES"].index(hex(id(test3))) + test_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"][ + "__WRAPYFI_INSTANCES" + ].index(hex(id(test))) + test3_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"][ + "__WRAPYFI_INSTANCES" + ].index(hex(id(test3))) self.assertEqual(test_id, 1) self.assertEqual(test3_id, 0) @@ -112,14 +132,22 @@ def test_close(self): del test self.assertIn("Test.exchange_object", Test._MiddlewareCommunicator__registry) - self.assertNotIn("Test.exchange_object.2", Test._MiddlewareCommunicator__registry) - self.assertNotIn("Test.exchange_object.3", Test._MiddlewareCommunicator__registry) + self.assertNotIn( + "Test.exchange_object.2", Test._MiddlewareCommunicator__registry + ) + self.assertNotIn( + "Test.exchange_object.3", Test._MiddlewareCommunicator__registry + ) # close all instances after all have been closed Test.close_all_instances() self.assertIn("Test.exchange_object", Test._MiddlewareCommunicator__registry) - self.assertNotIn("Test.exchange_object.2", Test._MiddlewareCommunicator__registry) - self.assertNotIn("Test.exchange_object.3", Test._MiddlewareCommunicator__registry) + self.assertNotIn( + "Test.exchange_object.2", Test._MiddlewareCommunicator__registry + ) + self.assertNotIn( + "Test.exchange_object.3", Test._MiddlewareCommunicator__registry + ) def test_get_communicators(self): """ @@ -127,6 +155,7 @@ def test_get_communicators(self): method should return a list of all available middleware wrappers. """ import wrapyfi.tests.tools.class_test as class_test + # importlib.reload(class_test) Test = class_test.Test @@ -144,6 +173,7 @@ class ROS2TestWrapper(ZeroMQTestWrapper): Test the ROS 2 wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class are also run for the ROS 2 wrapper. """ + MWARE = "ros2" @@ -152,6 +182,7 @@ class YarpTestWrapper(ZeroMQTestWrapper): Test the YARP wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class are also run for the YARP wrapper. """ + MWARE = "yarp" @@ -160,8 +191,9 @@ class ROSTestWrapper(ZeroMQTestWrapper): Test the ROS wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class are also run for the ROS wrapper. """ + MWARE = "ros" -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/wrapyfi/tests/tools/benchmarking_native_object.py b/wrapyfi/tests/tools/benchmarking_native_object.py index 0bac873..138dd68 100644 --- a/wrapyfi/tests/tools/benchmarking_native_object.py +++ b/wrapyfi/tests/tools/benchmarking_native_object.py @@ -1,5 +1,6 @@ import argparse import time + try: import numpy as np import pandas as pd @@ -7,8 +8,9 @@ print("Install pandas and NumPy before running this script.") try: import tensorflow as tf + # avoid allocating all GPU memory assuming tf>=2.2 - gpus = tf.config.experimental.list_physical_devices('GPU') + gpus = tf.config.experimental.list_physical_devices("GPU") for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) except ImportError: @@ -28,11 +30,16 @@ def get_numpy_object(dims): @staticmethod def get_pandas_object(dims): - return {"pandas": pd.DataFrame(np.ones(dims), index=None, columns=list(range(dims[-1])))} + return { + "pandas": pd.DataFrame( + np.ones(dims), index=None, columns=list(range(dims[-1])) + ) + } @staticmethod def get_pillow_object(dims): from PIL import Image + return {"pillow": Image.fromarray(np.ones(dims, dtype=np.uint8))} @staticmethod @@ -42,96 +49,178 @@ def get_tensorflow_object(dims): @staticmethod def get_jax_object(dims): import jax as jx + return {"jax": jx.numpy.ones(dims)} @staticmethod def get_mxnet_object(dims): import mxnet as mx + return {"mxnet": mx.nd.ones(dims)} @staticmethod def get_mxnet_gpu_object(dims, gpu=0): import mxnet as mx + return {"mxnet_gpu": mx.nd.ones(dims, ctx=mx.gpu(gpu))} @staticmethod def get_pytorch_object(dims): import torch as th + return {"pytorch": th.ones(dims)} @staticmethod def get_pytorch_gpu_object(dims, gpu=0): import torch as th + return {"pytorch_gpu": th.ones(dims, device=f"cuda:{gpu}")} @staticmethod def get_paddle_object(dims): import paddle as pa + return {"paddle": pa.Tensor(pa.ones(dims), place=pa.CPUPlace())} @staticmethod def get_paddle_gpu_object(dims, gpu=0): import paddle as pa + return {"paddle_gpu": pa.Tensor(pa.zeros(dims), place=pa.CUDAPlace(gpu))} def get_all_objects(self, count, plugin_name): - obj = {"count": count, - "time": time.time()} - obj.update(**getattr(self, f"get_{plugin_name}_object")((args.height, args.width,))) + obj = {"count": count, "time": time.time()} + obj.update( + **getattr(self, f"get_{plugin_name}_object")( + ( + args.height, + args.width, + ) + ) + ) return obj - @MiddlewareCommunicator.register("NativeObject", "yarp", - "ExampleClass", "/example/get_native_objects", - carrier="tcp", should_wait=SHOULD_WAIT) + @MiddlewareCommunicator.register( + "NativeObject", + "yarp", + "ExampleClass", + "/example/get_native_objects", + carrier="tcp", + should_wait=SHOULD_WAIT, + ) def get_yarp_native_objects(self, count, plugin_name): - return self.get_all_objects(count, plugin_name), + return (self.get_all_objects(count, plugin_name),) - @MiddlewareCommunicator.register("NativeObject", "ros", - "ExampleClass", "/example/get_native_objects", - carrier="tcp", should_wait=SHOULD_WAIT) + @MiddlewareCommunicator.register( + "NativeObject", + "ros", + "ExampleClass", + "/example/get_native_objects", + carrier="tcp", + should_wait=SHOULD_WAIT, + ) def get_ros_native_objects(self, count, plugin_name): - return self.get_all_objects(count, plugin_name), + return (self.get_all_objects(count, plugin_name),) - @MiddlewareCommunicator.register("NativeObject", "ros2", - "ExampleClass", "/example/get_native_objects", - carrier="tcp", should_wait=SHOULD_WAIT) + @MiddlewareCommunicator.register( + "NativeObject", + "ros2", + "ExampleClass", + "/example/get_native_objects", + carrier="tcp", + should_wait=SHOULD_WAIT, + ) def get_ros2_native_objects(self, count, plugin_name): - return self.get_all_objects(count, plugin_name), + return (self.get_all_objects(count, plugin_name),) - @MiddlewareCommunicator.register("NativeObject", "zeromq", - "ExampleClass", "/example/get_native_objects", - carrier="tcp", should_wait=SHOULD_WAIT) + @MiddlewareCommunicator.register( + "NativeObject", + "zeromq", + "ExampleClass", + "/example/get_native_objects", + carrier="tcp", + should_wait=SHOULD_WAIT, + ) def get_zeromq_native_objects(self, count, plugin_name): - return self.get_all_objects(count, plugin_name), + return (self.get_all_objects(count, plugin_name),) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--publish", dest="mode", action="store_const", const="publish", default="listen", help="Publish mode") - parser.add_argument("--listen", dest="mode", action="store_const", const="listen", default="listen", help="Listen mode (default)") - parser.add_argument("--mwares", type=str, default=list(MiddlewareCommunicator.get_communicators()), - choices=MiddlewareCommunicator.get_communicators(), nargs="+", - help="The middlewares to use for transmission") - parser.add_argument("--plugins", type=str, - default=["numpy", "pandas", "tensorflow", "jax", "mxnet", "mxnet_gpu", "pytorch", "pytorch_gpu", - "paddle", "paddle_gpu"], nargs="+", - help="The middlewares to use for transmission") - parser.add_argument("--height", type=int, default=200, help="The tensor image height") + parser.add_argument( + "--publish", + dest="mode", + action="store_const", + const="publish", + default="listen", + help="Publish mode", + ) + parser.add_argument( + "--listen", + dest="mode", + action="store_const", + const="listen", + default="listen", + help="Listen mode (default)", + ) + parser.add_argument( + "--mwares", + type=str, + default=list(MiddlewareCommunicator.get_communicators()), + choices=MiddlewareCommunicator.get_communicators(), + nargs="+", + help="The middlewares to use for transmission", + ) + parser.add_argument( + "--plugins", + type=str, + default=[ + "numpy", + "pandas", + "tensorflow", + "jax", + "mxnet", + "mxnet_gpu", + "pytorch", + "pytorch_gpu", + "paddle", + "paddle_gpu", + ], + nargs="+", + help="The middlewares to use for transmission", + ) + parser.add_argument( + "--height", type=int, default=200, help="The tensor image height" + ) parser.add_argument("--width", type=int, default=200, help="The tensor image width") - parser.add_argument("--trials", type=int, default=2000, help="Number of trials to run per middleware") - parser.add_argument("--skip-trials", type=int, default=0, help="Number of trials to skip before logging " - "to csv to avoid warmup time logging") + parser.add_argument( + "--trials", + type=int, + default=2000, + help="Number of trials to run per middleware", + ) + parser.add_argument( + "--skip-trials", + type=int, + default=0, + help="Number of trials to skip before logging " + "to csv to avoid warmup time logging", + ) return parser.parse_args() if __name__ == "__main__": args = parse_args() benchmarker = Benchmarker() - benchmark_logger = pd.DataFrame(columns=["middleware", "plugin", "timestamp", "count", "delay"]) + benchmark_logger = pd.DataFrame( + columns=["middleware", "plugin", "timestamp", "count", "delay"] + ) benchmark_iterator = {} for middleware_name in args.mwares: - benchmark_iterator[middleware_name] = getattr(benchmarker, f"get_{middleware_name}_native_objects") + benchmark_iterator[middleware_name] = getattr( + benchmarker, f"get_{middleware_name}_native_objects" + ) for middleware_name, method in benchmark_iterator.items(): benchmarker.activate_communication(method, mode=args.mode) @@ -140,20 +229,32 @@ def parse_args(): counter = -1 while True: counter += 1 - native_objects, = method(counter, plugin_name) + (native_objects,) = method(counter, plugin_name) if native_objects is not None: time_acc_native_objects.append(time.time() - native_objects["time"]) - print(f"{middleware_name} :: {plugin_name} :: delay:", time_acc_native_objects[-1], - " Length:", len(time_acc_native_objects), " Count:", native_objects["count"]) + print( + f"{middleware_name} :: {plugin_name} :: delay:", + time_acc_native_objects[-1], + " Length:", + len(time_acc_native_objects), + " Count:", + native_objects["count"], + ) if args.trials - 1 == native_objects["count"]: break if counter > args.skip_trials: - benchmark_logger = benchmark_logger.append(pd.DataFrame({"middleware": [middleware_name], - "plugin": [plugin_name], - "timestamp": [native_objects["timestamp"]], - "count": [native_objects["count"]], - "delay": [time_acc_native_objects[-1]]}), - ignore_index=True) + benchmark_logger = benchmark_logger.append( + pd.DataFrame( + { + "middleware": [middleware_name], + "plugin": [plugin_name], + "timestamp": [native_objects["timestamp"]], + "count": [native_objects["count"]], + "delay": [time_acc_native_objects[-1]], + } + ), + ignore_index=True, + ) if counter == 0: if args.mode == "publish": time.sleep(5) @@ -162,6 +263,12 @@ def parse_args(): time.sleep(0.1) time_acc_native_objects = pd.DataFrame(np.array(time_acc_native_objects)) - print(f"{middleware_name} :: {plugin_name} :: time statistics:", time_acc_native_objects.describe()) + print( + f"{middleware_name} :: {plugin_name} :: time statistics:", + time_acc_native_objects.describe(), + ) time.sleep(5) - benchmark_logger.to_csv(f"results/benchmarking_native_object_{args.mode}__{','.join(args.mwares)}__{','.join(args.plugins)}.csv", index=False) \ No newline at end of file + benchmark_logger.to_csv( + f"results/benchmarking_native_object_{args.mode}__{','.join(args.mwares)}__{','.join(args.plugins)}.csv", + index=False, + ) diff --git a/wrapyfi/tests/tools/class_test.py b/wrapyfi/tests/tools/class_test.py index 5703e37..44de6e7 100644 --- a/wrapyfi/tests/tools/class_test.py +++ b/wrapyfi/tests/tools/class_test.py @@ -6,24 +6,48 @@ class Test(MiddlewareCommunicator): - @MiddlewareCommunicator.register("NativeObject", "$mware", "Test", "$topic", - should_wait="$should_wait") - def exchange_object(self, msg=None, mware=DEFAULT_COMMUNICATOR, topic="/test/test_native_exchange", should_wait=False): - ret = {"message": msg, - "set": {'a', 1, None}, - "list": [[[3, [4], 5.677890, 1.2]]], - "string": "string of characters", - "2": 2.73211, - "dict": {"other": [None, False, 16, 4.32,]}} - return ret, + @MiddlewareCommunicator.register( + "NativeObject", "$mware", "Test", "$topic", should_wait="$should_wait" + ) + def exchange_object( + self, + msg=None, + mware=DEFAULT_COMMUNICATOR, + topic="/test/test_native_exchange", + should_wait=False, + ): + ret = { + "message": msg, + "set": {"a", 1, None}, + "list": [[[3, [4], 5.677890, 1.2]]], + "string": "string of characters", + "2": 2.73211, + "dict": { + "other": [ + None, + False, + 16, + 4.32, + ] + }, + } + return (ret,) -def test_func(queue_buffer, mode="listen", mware=DEFAULT_COMMUNICATOR, topic="/test/test_native_exchange", iterations=2, - should_wait=False): +def test_func( + queue_buffer, + mode="listen", + mware=DEFAULT_COMMUNICATOR, + topic="/test/test_native_exchange", + iterations=2, + should_wait=False, +): test = Test() test.activate_communication(test.exchange_object, mode=mode) for i in range(iterations): - my_message = test.exchange_object(msg=f"signal_idx:{i}", mware=mware, topic=topic, should_wait=should_wait) + my_message = test.exchange_object( + msg=f"signal_idx:{i}", mware=mware, topic=topic, should_wait=should_wait + ) if my_message is not None: print(f"result {mode}:", my_message[0]["message"]) queue_buffer.put(my_message[0]) diff --git a/wrapyfi/utils.py b/wrapyfi/utils.py index 95f27d0..6901413 100755 --- a/wrapyfi/utils.py +++ b/wrapyfi/utils.py @@ -12,7 +12,11 @@ lock = threading.Lock() -def deepcopy(obj: Any, exclude_keys: Optional[Union[list, tuple]] = None, shallow_keys: Optional[Union[list, tuple]] = None): +def deepcopy( + obj: Any, + exclude_keys: Optional[Union[list, tuple]] = None, + shallow_keys: Optional[Union[list, tuple]] = None, +): """ Deep copy an object, excluding specified keys. @@ -21,6 +25,7 @@ def deepcopy(obj: Any, exclude_keys: Optional[Union[list, tuple]] = None, shallo :param shallow_keys: Union[list, tuple]: A list of keys to shallow copy """ import copy + if exclude_keys is None: return copy.deepcopy(obj) else: @@ -32,7 +37,11 @@ def deepcopy(obj: Any, exclude_keys: Optional[Union[list, tuple]] = None, shallo return {deepcopy(item, exclude_keys) for item in obj} elif isinstance(obj, dict): _shallows = shallow_keys or [] - ret = {key: deepcopy(val, exclude_keys) for key, val in obj.items() if key not in exclude_keys + _shallows} + ret = { + key: deepcopy(val, exclude_keys) + for key, val in obj.items() + if key not in exclude_keys + _shallows + } ret.update({key: val for key, val in obj.items() if key in _shallows}) return ret else: @@ -46,6 +55,7 @@ def get_default_args(fnc: Callable[..., Any]): :param fnc: Callable[..., Any]: The function to get the default arguments for """ import inspect + signature = inspect.signature(fnc) return { k: v.default @@ -54,7 +64,12 @@ def get_default_args(fnc: Callable[..., Any]): } -def match_args(args: Union[list, tuple], kwargs: dict, src_args: Union[list, tuple], src_kwargs: dict): +def match_args( + args: Union[list, tuple], + kwargs: dict, + src_args: Union[list, tuple], + src_kwargs: dict, +): """ Match and Substitute Arguments and Keyword Arguments using Specified Source Values. @@ -91,7 +106,11 @@ def match_args(args: Union[list, tuple], kwargs: dict, src_args: Union[list, tup for kwarg_key, kwarg_val in kwargs.items(): if isinstance(kwarg_val, str) and "$" in kwarg_val and kwarg_val[1:].isdigit(): new_kwargs[kwarg_key] = src_args[int(kwarg_val[1:])] - elif isinstance(kwarg_val, str) and "$" in kwarg_val and not kwarg_val[1:].isdigit(): + elif ( + isinstance(kwarg_val, str) + and "$" in kwarg_val + and not kwarg_val[1:].isdigit() + ): new_kwargs[kwarg_key] = src_kwargs[kwarg_val[1:]] else: new_kwargs[kwarg_key] = kwarg_val @@ -111,14 +130,14 @@ def dynamic_module_import(modules: List[str], globals: dict): module_name = module_name[:-3] module_name = module_name.replace("/", ".") try: - module = __import__(module_name, fromlist=['*']) + module = __import__(module_name, fromlist=["*"]) except ImportError as e: # print(module_name + " could not be imported.", e) continue - if hasattr(module, '__all__'): + if hasattr(module, "__all__"): all_names = module.__all__ else: - all_names = [name for name in dir(module) if not name.startswith('_')] + all_names = [name for name in dir(module) if not name.startswith("_")] globals.update({name: getattr(module, name) for name in all_names}) @@ -128,13 +147,16 @@ class SingletonOptimized(type): Source: https://stackoverflow.com/a/6798042 """ + _instances = {} def __call__(cls, *args, **kwargs): if cls not in cls._instances: with lock: if cls not in cls._instances: - cls._instances[cls] = super(SingletonOptimized, cls).__call__(*args, **kwargs) + cls._instances[cls] = super(SingletonOptimized, cls).__call__( + *args, **kwargs + ) return cls._instances[cls] @@ -142,6 +164,7 @@ class Plugin(object): """ Base class for encoding and decoding plugins. """ + def encode(self, *args, **kwargs): """ Encode data into a base64 string. @@ -172,6 +195,7 @@ class PluginRegistrar(object): """ Class for registering encoding and decoding plugins. """ + encoder_registry = {} decoder_registry = {} @@ -182,12 +206,14 @@ def register(types=None): :param types: tuple: The type(s) to register the plugin for """ + def wrapper(cls): if types is not None: for cls_type in types: PluginRegistrar.encoder_registry[cls_type] = cls PluginRegistrar.decoder_registry[str(cls.__name__)] = cls return cls + return wrapper @staticmethod @@ -196,20 +222,30 @@ def scan(): Scan the plugins directory (Wrapyfi builtin and external) for plugins to register. This method is called automatically when the module is imported. """ - modules = glob(os.path.join(os.path.dirname(__file__), "plugins", "*.py"), recursive=True) - modules = ["wrapyfi.plugins." + module.replace(os.path.dirname(__file__) + "/plugins/", "") for module in modules] + modules = glob( + os.path.join(os.path.dirname(__file__), "plugins", "*.py"), recursive=True + ) + modules = [ + "wrapyfi.plugins." + + module.replace(os.path.dirname(__file__) + "/plugins/", "") + for module in modules + ] dynamic_module_import(modules, globals()) extern_modules_paths = os.environ.get(WRAPYFI_PLUGIN_PATHS, "").split(":") for mod_group_idx, extern_module_path in enumerate(extern_modules_paths): - extern_modules = glob(os.path.join(extern_module_path, "plugins", "*.py"), recursive=True) + extern_modules = glob( + os.path.join(extern_module_path, "plugins", "*.py"), recursive=True + ) for mod_idx, extern_module in enumerate(extern_modules): - spec = importlib.util.spec_from_file_location(f"wrapyfi.extern{mod_group_idx}.plugins{mod_idx}", extern_module) + spec = importlib.util.spec_from_file_location( + f"wrapyfi.extern{mod_group_idx}.plugins{mod_idx}", extern_module + ) module = importlib.util.module_from_spec(spec) sys.modules[spec.name] = module spec.loader.exec_module(module) - extern_modules = [f"wrapyfi.extern{mod_group_idx}.plugins{mod_idx}." + extern_module.replace( - extern_module_path + "/plugins/", "")] + extern_modules = [ + f"wrapyfi.extern{mod_group_idx}.plugins{mod_idx}." + + extern_module.replace(extern_module_path + "/plugins/", "") + ] dynamic_module_import(extern_modules, globals()) - -