diff --git a/.github/workflows/manual_update_version.yml b/.github/workflows/manual_update_version.yml
index 344079f..556d60a 100644
--- a/.github/workflows/manual_update_version.yml
+++ b/.github/workflows/manual_update_version.yml
@@ -27,6 +27,11 @@ jobs:
# sed -i.bak -E "s|[^<]+|$VERSION|" wrapyfi_extensions/wrapyfi_ros2_interfaces/package.xml
# sed -i.bak -E "s|[^<]+|$VERSION|" wrapyfi_extensions/wrapyfi_ros_interfaces/package.xml
+ - name: Refactor with Code Black
+ run: |
+ python3 -m pip install black
+ black .
+
- name: Generate Documentation
run: |
cd docs
diff --git a/README.md b/README.md
index cbbe323..95e5ab9 100755
--- a/README.md
+++ b/README.md
@@ -14,12 +14,17 @@
[![arXiv](https://custom-icon-badges.demolab.com/badge/arXiv:2302.09648-lightyellow.svg?logo=arxiv-logomark-small)](https://arxiv.org/abs/2302.09648 "arXiv link")
[![doi](https://custom-icon-badges.demolab.com/badge/10.1145/3610977.3637471-lightyellow.svg?logo=doi_logo)](https://doi.org/10.1145/3610977.3637471 "doi link")
+
+
[![License](https://custom-icon-badges.demolab.com/github/license/denvercoder1/custom-icon-badges?logo=law&logoColor=white)](https://github.com/fabawi/wrapyfi/blob/main/LICENSE "license MIT")
[![FOSSA status](https://app.fossa.com/api/projects/git%2Bgithub.com%2Fmodular-ml%2Fwrapyfi.svg?type=shield&issueType=license)](https://app.fossa.com/projects/git%2Bgithub.com%2Fmodular-ml%2Fwrapyfi?ref=badge_shield&issueType=license)
[![Documentation status](https://readthedocs.org/projects/wrapyfi/badge/?version=latest)](https://wrapyfi.readthedocs.io/en/latest/?badge=latest)
-[![codecov](https://codecov.io/github/modular-ml/wrapyfi/graph/badge.svg?token=5SD1A6ENKE)](https://codecov.io/github/modular-ml/wrapyfi)
+[![codecov](https://codecov.io/github/modular-ml/wrapyfi/graph/badge.svg?token=5SD1A6ENKE)](https://codecov.io/github/modular-ml/wrapyfi)
[![PyPI version](https://badge.fury.io/py/wrapyfi.svg)](https://badge.fury.io/py/wrapyfi)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://black.readthedocs.io/ "code style link")
+[![PyPI - Implementation](https://img.shields.io/pypi/implementation/wrapyfi)](https://pypi.org/project/wrapyfi/ "implementation")
+[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/wrapyfi)](https://pypi.org/project/wrapyfi/ "python version")
[![PyPI total downloads](https://img.shields.io/pepy/dt/wrapyfi)](https://www.pepy.tech/projects/wrapyfi)
[![Docker Hub Pulls](https://img.shields.io/docker/pulls/modularml/wrapyfi.svg)](https://hub.docker.com/repository/docker/modularml/wrapyfi)
diff --git a/docs/_extensions/link_modifier.py b/docs/_extensions/link_modifier.py
index 122686d..7e7fc97 100644
--- a/docs/_extensions/link_modifier.py
+++ b/docs/_extensions/link_modifier.py
@@ -3,38 +3,36 @@
import re
REPLACEMENTS = {
- 'https://github.com/fabawi/wrapyfi/tree/main/wrapyfi_extensions/yarp/README.md':
- 'yarp_install_lnk.html',
- 'https://github.com/modular-ml/wrapyfi_ros2_interfaces/blob/master/README.md':
- 'ros2_interfaces_lnk.html',
- 'https://github.com/modular-ml/wrapyfi_ros_interfaces/blob/master/README.md':
- 'ros_interfaces_lnk.html',
- 'https://github.com/fabawi/wrapyfi/tree/main/dockerfiles/README.md':
- 'wrapyfi_docker_lnk.html',
+ "https://github.com/fabawi/wrapyfi/tree/main/wrapyfi_extensions/yarp/README.md": "yarp_install_lnk.html",
+ "https://github.com/modular-ml/wrapyfi_ros2_interfaces/blob/master/README.md": "ros2_interfaces_lnk.html",
+ "https://github.com/modular-ml/wrapyfi_ros_interfaces/blob/master/README.md": "ros_interfaces_lnk.html",
+ "https://github.com/fabawi/wrapyfi/tree/main/dockerfiles/README.md": "wrapyfi_docker_lnk.html",
}
+
class LinkModifier(SphinxTransform):
default_priority = 999
def apply(self):
for node in self.document.traverse(nodes.reference):
- uri = node.get('refuri', '')
+ uri = node.get("refuri", "")
for link in REPLACEMENTS.keys():
if link in uri:
# Extract rank value
- match = re.search(r'\?rank=(-?\d+)', uri)
+ match = re.search(r"\?rank=(-?\d+)", uri)
if match:
rank = int(match.group(1))
if rank < 0:
- prefix = '../' * -rank
+ prefix = "../" * -rank
new_uri = prefix + REPLACEMENTS[link]
else:
new_uri = REPLACEMENTS[link]
else:
# Default to rank -2 if no rank parameter
- new_uri = '../../' + REPLACEMENTS[link]
+ new_uri = "../../" + REPLACEMENTS[link]
+
+ node["refuri"] = uri.replace(link, new_uri)
- node['refuri'] = uri.replace(link, new_uri)
def setup(app):
app.add_transform(LinkModifier)
diff --git a/docs/_extensions/math_block_converter.py b/docs/_extensions/math_block_converter.py
index dc4165a..c658afb 100644
--- a/docs/_extensions/math_block_converter.py
+++ b/docs/_extensions/math_block_converter.py
@@ -1,9 +1,11 @@
from sphinx.application import Sphinx
import re
+
def convert_math_blocks(app: Sphinx, docname: str, source: list):
if source:
- source[0] = re.sub(r'```math', '```{math}', source[0])
+ source[0] = re.sub(r"```math", "```{math}", source[0])
+
def setup(app: Sphinx):
app.connect("source-read", convert_math_blocks)
diff --git a/docs/conf.py b/docs/conf.py
index bfe9a9e..ab8653c 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -6,40 +6,41 @@
import ast
import json
+
def get_project_info_from_setup():
curr_dir = os.path.dirname(__file__)
- setup_path = os.path.join(curr_dir, '..', 'setup.py')
- with open(setup_path, 'r') as f:
+ setup_path = os.path.join(curr_dir, "..", "setup.py")
+ with open(setup_path, "r") as f:
content = f.read()
-
+
name_match = re.search(r"name\s*=\s*['\"]([^'\"]*)['\"]", content)
version_match = re.search(r"version\s*=\s*['\"]([^'\"]*)['\"]", content)
url_match = re.search(r"url\s*=\s*['\"]([^'\"]*)['\"]", content)
author_match = re.search(r"author\s*=\s*['\"]([^'\"]*)['\"]", content)
-
+
if not name_match or not version_match or not url_match:
raise RuntimeError("Unable to find name, version, url, or author string.")
-
+
return {
- 'name': name_match.group(1),
- 'version': version_match.group(1),
- 'url': url_match.group(1),
- 'author': author_match.group(1)
+ "name": name_match.group(1),
+ "version": version_match.group(1),
+ "url": url_match.group(1),
+ "author": author_match.group(1),
}
def get_imported_modules(package_name):
package = importlib.import_module(package_name)
imported_modules = []
- for _, module_name, _ in pkgutil.walk_packages(path=package.__path__,
- prefix=package.__name__ + '.',
- onerror=lambda x: None):
+ for _, module_name, _ in pkgutil.walk_packages(
+ path=package.__path__, prefix=package.__name__ + ".", onerror=lambda x: None
+ ):
imported_modules.append(module_name)
return imported_modules
def get_all_imports_in_file(file_path):
- with open(file_path, 'r', encoding='utf-8') as file:
+ with open(file_path, "r", encoding="utf-8") as file:
tree = ast.parse(file.read(), filename=file_path)
imports = set()
@@ -57,37 +58,45 @@ def get_all_imports_in_package(package_path):
all_imports = set()
for root, dirs, files in os.walk(package_path):
for file in files:
- if file.endswith('.py'):
+ if file.endswith(".py"):
file_path = os.path.join(root, file)
all_imports.update(get_all_imports_in_file(file_path))
return all_imports
+
def setup(app):
- app.add_css_file('wide_theme.css')
+ app.add_css_file("wide_theme.css")
autodoc_default_options = {
- 'members': True,
- 'member-order': 'bysource',
- 'special-members': '__init__',
- 'undoc-members': True,
- 'exclude-members': '__weakref__'
+ "members": True,
+ "member-order": "bysource",
+ "special-members": "__init__",
+ "undoc-members": True,
+ "exclude-members": "__weakref__",
}
-main_doc = 'index'
-html_theme = 'sphinx_rtd_theme'
-html_static_path = ['_static']
-html_css_files = ['wide_theme.css']
-
-extensions = ['sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'myst_parser', 'sphinx.ext.mathjax',
- 'math_block_converter', 'link_modifier']
-source_suffix = ['.rst', '.md']
+main_doc = "index"
+html_theme = "sphinx_rtd_theme"
+html_static_path = ["_static"]
+html_css_files = ["wide_theme.css"]
+
+extensions = [
+ "sphinx.ext.todo",
+ "sphinx.ext.viewcode",
+ "sphinx.ext.autodoc",
+ "myst_parser",
+ "sphinx.ext.mathjax",
+ "math_block_converter",
+ "link_modifier",
+]
+source_suffix = [".rst", ".md"]
exclude_patterns = ["_build"]
myst_enable_extensions = ["dollarmath", "amsmath"]
# mock all libraries except for the ones that are installed
-with open('exclude_packages.json', 'r') as f:
+with open("exclude_packages.json", "r") as f:
all_imported_modules_pre = set(x for x in json.load(f) if x is not None)
print(all_imported_modules_pre)
# all_imported_modules = get_all_imports_in_package("wrapyfi")
@@ -100,27 +109,31 @@ def setup(app):
# extract project info
project_info = get_project_info_from_setup()
-project = project_info['name']
-release = project_info['version']
-version = '.'.join(release.split('.')[:2])
-url = project_info['url']
-author = project_info['author']
+project = project_info["name"]
+release = project_info["version"]
+version = ".".join(release.split(".")[:2])
+url = project_info["url"]
+author = project_info["author"]
# modify the latex cover page for pdf generation
latex_elements = {
- 'preamble': r'''
+ "preamble": r"""
\usepackage{titling}
\pretitle{%
\begin{center}
\vspace{\droptitle}
\includegraphics[width=60mm]{../assets/wrapyfi.png}\\[\bigskipamount]
- \Large{\textbf{''' + project + '''}}\\
- \normalsize{v''' + release + '''}
+ \Large{\textbf{"""
+ + project
+ + """}}\\
+ \normalsize{v"""
+ + release
+ + """}
}
\posttitle{\end{center}}
-'''
+"""
}
-sys.path.insert(0, os.path.abspath('../'))
-sys.path.append(os.path.abspath('./_extensions'))
-sys.path.append(os.path.abspath('./mock_imports'))
+sys.path.insert(0, os.path.abspath("../"))
+sys.path.append(os.path.abspath("./_extensions"))
+sys.path.append(os.path.abspath("./mock_imports"))
diff --git a/docs/mock_imports/geometry_msgs/msg/__init__.py b/docs/mock_imports/geometry_msgs/msg/__init__.py
index c032b16..11e57eb 100644
--- a/docs/mock_imports/geometry_msgs/msg/__init__.py
+++ b/docs/mock_imports/geometry_msgs/msg/__init__.py
@@ -5,4 +5,4 @@ def __init__(self, *args, **kwargs):
class Quaternion(object):
def __init__(self, *args, **kwargs):
- pass
\ No newline at end of file
+ pass
diff --git a/docs/mock_imports/rclpy/__init__.py b/docs/mock_imports/rclpy/__init__.py
index 852e4aa..cc0f15f 100644
--- a/docs/mock_imports/rclpy/__init__.py
+++ b/docs/mock_imports/rclpy/__init__.py
@@ -1,3 +1,3 @@
class Parameter(object):
def __init__(self, *args, **kwargs):
- pass
\ No newline at end of file
+ pass
diff --git a/docs/mock_imports/std_msgs/msg/__init__.py b/docs/mock_imports/std_msgs/msg/__init__.py
index 2b53c11..5862ec4 100644
--- a/docs/mock_imports/std_msgs/msg/__init__.py
+++ b/docs/mock_imports/std_msgs/msg/__init__.py
@@ -1,3 +1,3 @@
class String(object):
def __init__(self, *args, **kwargs):
- pass
\ No newline at end of file
+ pass
diff --git a/examples/applications/affective_signaling_multirobot.py b/examples/applications/affective_signaling_multirobot.py
index 36cbdcc..d900623 100644
--- a/examples/applications/affective_signaling_multirobot.py
+++ b/examples/applications/affective_signaling_multirobot.py
@@ -19,93 +19,210 @@ class ExperimentController(MiddlewareCommunicator):
def __init__(self, **kwargs):
super(ExperimentController, self).__init__()
- @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController",
- "/control_interface/facial_expressions_esr9", should_wait=False)
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/facial_expressions_esr9",
+ should_wait=False,
+ )
def listen_facial_expressions_esr9(self, _mware=DEFAULT_COMMUNICATOR):
- return None,
+ return (None,)
- @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController",
- "/control_interface/facial_expressions_icub", should_wait=False)
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/facial_expressions_icub",
+ should_wait=False,
+ )
def publish_facial_expressions_icub(self, obj, _mware="yarp"):
- return obj,
+ return (obj,)
- @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController",
- "/control_interface/facial_expressions_pepper",
- carrier="tcp", should_wait=False)
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/facial_expressions_pepper",
+ carrier="tcp",
+ should_wait=False,
+ )
def publish_facial_expressions_pepper(self, obj, _mware="ros"):
- return obj,
-
- @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", "$_topic",
- width="$_width", height="$_height", rgb=True, fp=False, should_wait=False,
- jpg=True)
- def listen_image_webcam(self, _mware=DEFAULT_COMMUNICATOR, _width=WEBCAM_WIDTH, _height=WEBCAM_HEIGHT,
- _topic="/control_interface/image_webcam"): # dynamic topic according to corresponding API
- return None,
-
- @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", "$_topic",
- width="$_width", height="$_height", rgb=True, fp=False, should_wait=False)
- def listen_image_pepper_cam(self, _mware="ros", _width=PEPPER_CAM_WIDTH, _height=PEPPER_CAM_HEIGHT,
- _topic="/pepper/camera/front/camera/image_raw"):
- return None,
-
- @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController", "$_topic",
- width="$_width", height="$_height", rgb=True, fp=False, should_wait=False)
- def listen_image_icub_cam(self, _mware="yarp", _width=ICUB_CAM_WIDTH, _height=ICUB_CAM_HEIGHT,
- _topic="/icub/cam/right"):
- return None,
-
- @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController",
- "/control_interface/image_webcam",
- width="$_width", height="$_height", rgb=True, fp=False, should_wait=False)
- def forward_image_webcam(self, img, _mware=DEFAULT_COMMUNICATOR, _width=WEBCAM_WIDTH,
- _height=WEBCAM_HEIGHT):
- return img,
-
- @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController",
- "/control_interface/image_pepper_cam",
- width="$_width", height="$_height", rgb=True, fp=False, jpg=True,
- should_wait=False)
- def forward_image_pepper_cam(self, img, _mware=DEFAULT_COMMUNICATOR, _width=PEPPER_CAM_WIDTH,
- _height=PEPPER_CAM_HEIGHT):
- return img,
-
- @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController",
- "/control_interface/image_icub_cam",
- width="$_width", height="$_height", rgb=True, fp=False, jpg=True,
- should_wait=False)
- def forward_image_icub_cam(self, img, _mware=DEFAULT_COMMUNICATOR, _width=ICUB_CAM_WIDTH,
- _height=ICUB_CAM_HEIGHT):
+ return (obj,)
+
+ @MiddlewareCommunicator.register(
+ "Image",
+ "$_mware",
+ "ExperimentController",
+ "$_topic",
+ width="$_width",
+ height="$_height",
+ rgb=True,
+ fp=False,
+ should_wait=False,
+ jpg=True,
+ )
+ def listen_image_webcam(
+ self,
+ _mware=DEFAULT_COMMUNICATOR,
+ _width=WEBCAM_WIDTH,
+ _height=WEBCAM_HEIGHT,
+ _topic="/control_interface/image_webcam",
+ ): # dynamic topic according to corresponding API
+ return (None,)
+
+ @MiddlewareCommunicator.register(
+ "Image",
+ "$_mware",
+ "ExperimentController",
+ "$_topic",
+ width="$_width",
+ height="$_height",
+ rgb=True,
+ fp=False,
+ should_wait=False,
+ )
+ def listen_image_pepper_cam(
+ self,
+ _mware="ros",
+ _width=PEPPER_CAM_WIDTH,
+ _height=PEPPER_CAM_HEIGHT,
+ _topic="/pepper/camera/front/camera/image_raw",
+ ):
+ return (None,)
+
+ @MiddlewareCommunicator.register(
+ "Image",
+ "$_mware",
+ "ExperimentController",
+ "$_topic",
+ width="$_width",
+ height="$_height",
+ rgb=True,
+ fp=False,
+ should_wait=False,
+ )
+ def listen_image_icub_cam(
+ self,
+ _mware="yarp",
+ _width=ICUB_CAM_WIDTH,
+ _height=ICUB_CAM_HEIGHT,
+ _topic="/icub/cam/right",
+ ):
+ return (None,)
+
+ @MiddlewareCommunicator.register(
+ "Image",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/image_webcam",
+ width="$_width",
+ height="$_height",
+ rgb=True,
+ fp=False,
+ should_wait=False,
+ )
+ def forward_image_webcam(
+ self,
+ img,
+ _mware=DEFAULT_COMMUNICATOR,
+ _width=WEBCAM_WIDTH,
+ _height=WEBCAM_HEIGHT,
+ ):
+ return (img,)
+
+ @MiddlewareCommunicator.register(
+ "Image",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/image_pepper_cam",
+ width="$_width",
+ height="$_height",
+ rgb=True,
+ fp=False,
+ jpg=True,
+ should_wait=False,
+ )
+ def forward_image_pepper_cam(
+ self,
+ img,
+ _mware=DEFAULT_COMMUNICATOR,
+ _width=PEPPER_CAM_WIDTH,
+ _height=PEPPER_CAM_HEIGHT,
+ ):
+ return (img,)
+
+ @MiddlewareCommunicator.register(
+ "Image",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/image_icub_cam",
+ width="$_width",
+ height="$_height",
+ rgb=True,
+ fp=False,
+ jpg=True,
+ should_wait=False,
+ )
+ def forward_image_icub_cam(
+ self,
+ img,
+ _mware=DEFAULT_COMMUNICATOR,
+ _width=ICUB_CAM_WIDTH,
+ _height=ICUB_CAM_HEIGHT,
+ ):
# convert to bgr
if img is not None:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- return img,
+ return (img,)
- @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController",
- "/control_interface/image_esr9",
- width="$_width", height="$_height", rgb=True, fp=False, should_wait=False, jpg=True)
- def publish_esr9_cam(self, img, _mware=DEFAULT_COMMUNICATOR, _width=ESR_CAM_WIDTH, _height=ESR_CAM_HEIGHT):
+ @MiddlewareCommunicator.register(
+ "Image",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/image_esr9",
+ width="$_width",
+ height="$_height",
+ rgb=True,
+ fp=False,
+ should_wait=False,
+ jpg=True,
+ )
+ def publish_esr9_cam(
+ self,
+ img,
+ _mware=DEFAULT_COMMUNICATOR,
+ _width=ESR_CAM_WIDTH,
+ _height=ESR_CAM_HEIGHT,
+ ):
# modify size of image
if img is not None:
img = cv2.resize(img, (self.ESR_CAM_WIDTH, self.ESR_CAM_HEIGHT))
- return img,
+ return (img,)
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("--wrapyfi_cfg",
- help="File to load Wrapyfi configs for running instance. "
- "Choose one of the configs available in "
- "./wrapyfi_configs/affective_signaling_multirobot "
- "for each running instance. All configs with a prefix of COMP must "
- "have corresponding instances running. OPT prefixed configs execute "
- "scripts on a machine connected to either robot (Pepper/iCub) or both, "
- "and at least one must run in addition to COMP.",
- type=str)
- parser.add_argument("--cam_source",
- help="The camera input source being either from a "
- "webcam, Pepper, or iCub. Note that this must be similar for all running "
- "instances even when the two robots are set to display the expressions.",
- type=str, default="webcam", choices=["webcam", "pepper", "icub"])
+ parser.add_argument(
+ "--wrapyfi_cfg",
+ help="File to load Wrapyfi configs for running instance. "
+ "Choose one of the configs available in "
+ "./wrapyfi_configs/affective_signaling_multirobot "
+ "for each running instance. All configs with a prefix of COMP must "
+ "have corresponding instances running. OPT prefixed configs execute "
+ "scripts on a machine connected to either robot (Pepper/iCub) or both, "
+ "and at least one must run in addition to COMP.",
+ type=str,
+ )
+ parser.add_argument(
+ "--cam_source",
+ help="The camera input source being either from a "
+ "webcam, Pepper, or iCub. Note that this must be similar for all running "
+ "instances even when the two robots are set to display the expressions.",
+ type=str,
+ default="webcam",
+ choices=["webcam", "pepper", "icub"],
+ )
return parser.parse_args()
@@ -119,8 +236,8 @@ def parse_args():
image_cam = None
while True:
if args.cam_source == "webcam":
- image_webcam, = ec.listen_image_webcam()
- image_webcam_linked, = ec.forward_image_webcam(image_webcam)
+ (image_webcam,) = ec.listen_image_webcam()
+ (image_webcam_linked,) = ec.forward_image_webcam(image_webcam)
if image_webcam is None:
image_webcam = image_webcam_linked
image_cam = image_webcam
@@ -128,8 +245,8 @@ def parse_args():
cv2.imshow("Webcam image", image_webcam)
cv2.waitKey(1)
if args.cam_source == "pepper":
- image_pepper_cam, = ec.listen_image_pepper_cam()
- image_pepper_cam_linked, = ec.forward_image_pepper_cam(image_pepper_cam)
+ (image_pepper_cam,) = ec.listen_image_pepper_cam()
+ (image_pepper_cam_linked,) = ec.forward_image_pepper_cam(image_pepper_cam)
if image_pepper_cam is None:
image_pepper_cam = image_pepper_cam_linked
image_cam = image_pepper_cam
@@ -137,8 +254,8 @@ def parse_args():
cv2.imshow("Pepper image", image_pepper_cam)
cv2.waitKey(1)
if args.cam_source == "icub":
- image_icub_cam, = ec.listen_image_icub_cam()
- image_icub_cam_linked, = ec.forward_image_icub_cam(image_icub_cam)
+ (image_icub_cam,) = ec.listen_image_icub_cam()
+ (image_icub_cam_linked,) = ec.forward_image_icub_cam(image_icub_cam)
if image_icub_cam is None:
image_icub_cam = image_icub_cam_linked
image_cam = image_icub_cam
@@ -146,9 +263,8 @@ def parse_args():
cv2.imshow("iCub image", image_icub_cam_linked)
cv2.waitKey(1)
- image_esr, = ec.publish_esr9_cam(image_cam)
- facial_expression, = ec.listen_facial_expressions_esr9()
+ (image_esr,) = ec.publish_esr9_cam(image_cam)
+ (facial_expression,) = ec.listen_facial_expressions_esr9()
if facial_expression is not None:
ec.publish_facial_expressions_icub(facial_expression)
ec.publish_facial_expressions_pepper(facial_expression)
-
diff --git a/examples/applications/gaze_mirroring_multisensor.py b/examples/applications/gaze_mirroring_multisensor.py
index f5181b5..30eca32 100644
--- a/examples/applications/gaze_mirroring_multisensor.py
+++ b/examples/applications/gaze_mirroring_multisensor.py
@@ -17,46 +17,103 @@ class ExperimentController(MiddlewareCommunicator):
def __init__(self, **kwargs):
super(ExperimentController, self).__init__()
- @MiddlewareCommunicator.register("NativeObject", "$_sixdrepnet_mware", "ExperimentController",
- "/control_interface/orientation_sixdrepnet", should_wait=False)
- @MiddlewareCommunicator.register("NativeObject", "$_waveshareimu_mware", "ExperimentController",
- "/control_interface/orientation_waveshareimu", should_wait=False)
- def listen_orientation_sixdrepnet_waveshareimu(self, _sixdrepnet_mware="yarp", _waveshareimu_mware="yarp"):
- return None, None,
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController",
- "/control_interface/orientation_icub", should_wait=False)
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_sixdrepnet_mware",
+ "ExperimentController",
+ "/control_interface/orientation_sixdrepnet",
+ should_wait=False,
+ )
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_waveshareimu_mware",
+ "ExperimentController",
+ "/control_interface/orientation_waveshareimu",
+ should_wait=False,
+ )
+ def listen_orientation_sixdrepnet_waveshareimu(
+ self, _sixdrepnet_mware="yarp", _waveshareimu_mware="yarp"
+ ):
+ return (
+ None,
+ None,
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/orientation_icub",
+ should_wait=False,
+ )
def publish_orientation_icub(self, obj, _mware="yarp"):
- return obj,
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController",
- "/control_interface/gaze_pupil", should_wait=False)
+ return (obj,)
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/gaze_pupil",
+ should_wait=False,
+ )
def listen_gaze_pupil(self, _mware="zeromq"):
- return None,
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware", "ExperimentController",
- "/control_interface/gaze_icub", should_wait=False)
+ return (None,)
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/gaze_icub",
+ should_wait=False,
+ )
def publish_gaze_icub(self, obj, _mware="yarp"):
- return obj,
-
-
- @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController",
- "/control_interface/image_webcam",
- width="$_width", height="$_height", rgb=True, fp=False, should_wait=False,
- jpg=True)
- def listen_image_webcam(self, _mware="ros2", _width=WEBCAM_WIDTH, _height=WEBCAM_HEIGHT):
- return None,
-
- @MiddlewareCommunicator.register("Image", "$_mware", "ExperimentController",
- "/control_interface/image_sixdrepnet",
- width="$_width", height="$_height", rgb=True, fp=False, should_wait=False)
- def publish_sixdrepnet_cam(self, img, _mware="ros2", _width=SIXDREPNET_CAM_WIDTH, _height=SIXDREPNET_CAM_HEIGHT):
+ return (obj,)
+
+ @MiddlewareCommunicator.register(
+ "Image",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/image_webcam",
+ width="$_width",
+ height="$_height",
+ rgb=True,
+ fp=False,
+ should_wait=False,
+ jpg=True,
+ )
+ def listen_image_webcam(
+ self, _mware="ros2", _width=WEBCAM_WIDTH, _height=WEBCAM_HEIGHT
+ ):
+ return (None,)
+
+ @MiddlewareCommunicator.register(
+ "Image",
+ "$_mware",
+ "ExperimentController",
+ "/control_interface/image_sixdrepnet",
+ width="$_width",
+ height="$_height",
+ rgb=True,
+ fp=False,
+ should_wait=False,
+ )
+ def publish_sixdrepnet_cam(
+ self,
+ img,
+ _mware="ros2",
+ _width=SIXDREPNET_CAM_WIDTH,
+ _height=SIXDREPNET_CAM_HEIGHT,
+ ):
# modify size of image
if img is not None:
- img = cv2.resize(img, (self.SIXDREPNET_CAM_WIDTH, self.SIXDREPNET_CAM_HEIGHT))
- return img,
-
- def priority_control_sources(self, orientation_sixdrepnet, orientation_imu, control_sources):
+ img = cv2.resize(
+ img, (self.SIXDREPNET_CAM_WIDTH, self.SIXDREPNET_CAM_HEIGHT)
+ )
+ return (img,)
+
+ def priority_control_sources(
+ self, orientation_sixdrepnet, orientation_imu, control_sources
+ ):
if control_sources[0] == "vision":
if orientation_sixdrepnet is not None:
# print("Orientation 6DRepNet: ", orientation_sixdrepnet)
@@ -72,27 +129,38 @@ def priority_control_sources(self, orientation_sixdrepnet, orientation_imu, cont
# print("Orientation 6DRepNet: ", orientation_sixdrepnet)
self.publish_orientation_icub(orientation_sixdrepnet)
+
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("--wrapyfi_cfg",
- help="File to load Wrapyfi configs for running instance. "
- "Choose one of the configs available in "
- "./wrapyfi_configs/gaze_mirroring_multisensor. "
- "for each running instance. All configs with a prefix of COMP must "
- "have corresponding instances running. OPT prefixed config is optional "
- "(only when using a vision model) executes the script on a machine connected "
- "to the camera (or has access to the camera topic).",
- type=str)
- parser.add_argument("--control_sources",
- help="Control sources to use for the experiment. The order of sources indicates the priority. "
- "For example, if vision is the first source, then the vision source will "
- "be used for control. If vision is not available, then the IMU source will be used. "
- "If one source is provided, then the experiment will run with that source. ",
- type=str, default=["vision", "imu"], nargs="+", choices=["vision", "imu"])
- parser.add_argument("--enable_gaze",
- help="Enable the gaze (eye movement) control. If not enabled, "
- "then the gaze control will not be used. ",
- action="store_true", default=False)
+ parser.add_argument(
+ "--wrapyfi_cfg",
+ help="File to load Wrapyfi configs for running instance. "
+ "Choose one of the configs available in "
+ "./wrapyfi_configs/gaze_mirroring_multisensor. "
+ "for each running instance. All configs with a prefix of COMP must "
+ "have corresponding instances running. OPT prefixed config is optional "
+ "(only when using a vision model) executes the script on a machine connected "
+ "to the camera (or has access to the camera topic).",
+ type=str,
+ )
+ parser.add_argument(
+ "--control_sources",
+ help="Control sources to use for the experiment. The order of sources indicates the priority. "
+ "For example, if vision is the first source, then the vision source will "
+ "be used for control. If vision is not available, then the IMU source will be used. "
+ "If one source is provided, then the experiment will run with that source. ",
+ type=str,
+ default=["vision", "imu"],
+ nargs="+",
+ choices=["vision", "imu"],
+ )
+ parser.add_argument(
+ "--enable_gaze",
+ help="Enable the gaze (eye movement) control. If not enabled, "
+ "then the gaze control will not be used. ",
+ action="store_true",
+ default=False,
+ )
return parser.parse_args()
@@ -107,8 +175,8 @@ def parse_args():
image_cam = None
while True:
if "vision" in args.control_sources:
- image_webcam, = ec.listen_image_webcam()
- image_webcam_linked, = ec.publish_sixdrepnet_cam(image_cam)
+ (image_webcam,) = ec.listen_image_webcam()
+ (image_webcam_linked,) = ec.publish_sixdrepnet_cam(image_cam)
if image_webcam is None:
image_webcam = image_webcam_linked
image_cam = image_webcam
@@ -116,15 +184,15 @@ def parse_args():
cv2.imshow("Webcam image", image_webcam)
cv2.waitKey(1)
- orientation_sixdrepnet, orientation_imu = ec.listen_orientation_sixdrepnet_waveshareimu()
- ec.priority_control_sources(orientation_sixdrepnet, orientation_imu, args.control_sources)
+ orientation_sixdrepnet, orientation_imu = (
+ ec.listen_orientation_sixdrepnet_waveshareimu()
+ )
+ ec.priority_control_sources(
+ orientation_sixdrepnet, orientation_imu, args.control_sources
+ )
if args.enable_gaze:
- gaze_pupil, = ec.listen_gaze_pupil()
+ (gaze_pupil,) = ec.listen_gaze_pupil()
if gaze_pupil is not None:
# print("Gaze Pupil: ", gaze_pupil)
ec.publish_gaze_icub(gaze_pupil)
-
-
-
-
diff --git a/examples/communication_patterns/request_reply_example.py b/examples/communication_patterns/request_reply_example.py
index 1984b27..be69234 100755
--- a/examples/communication_patterns/request_reply_example.py
+++ b/examples/communication_patterns/request_reply_example.py
@@ -56,16 +56,28 @@
class ReqRep(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "ReqRep", "/req_rep/my_message",
- carrier="tcp", persistent=True
+ "NativeObject",
+ "$mware",
+ "ReqRep",
+ "/req_rep/my_message",
+ carrier="tcp",
+ persistent=True,
)
@MiddlewareCommunicator.register(
- "Image", "$mware", "ReqRep", "/req_rep/my_image_message",
- carrier="", width="$img_width", height="$img_height", rgb=True, jpg=True,
- persistent=True
+ "Image",
+ "$mware",
+ "ReqRep",
+ "/req_rep/my_image_message",
+ carrier="",
+ width="$img_width",
+ height="$img_height",
+ rgb=True,
+ jpg=True,
+ persistent=True,
)
- def send_img_message(self, msg=None, img_width=320, img_height=240,
- mware=None, *args, **kwargs):
+ def send_img_message(
+ self, msg=None, img_width=320, img_height=240, mware=None, *args, **kwargs
+ ):
"""
Exchange messages with OpenCV images and other native Python objects.
"""
@@ -74,26 +86,62 @@ def send_img_message(self, msg=None, img_width=320, img_height=240,
# read image from file
img = cv2.imread("../../assets/wrapyfi.png")
img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_AREA)
- cv2.putText(img, msg,
- ((img.shape[1] - cv2.getTextSize(msg, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0][0]) // 2,
- (img.shape[0] + cv2.getTextSize(msg, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0][1]) // 2),
- cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)
+ cv2.putText(
+ img,
+ msg,
+ (
+ (
+ img.shape[1]
+ - cv2.getTextSize(msg, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0][0]
+ )
+ // 2,
+ (
+ img.shape[0]
+ + cv2.getTextSize(msg, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0][1]
+ )
+ // 2,
+ ),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 1,
+ (255, 0, 0),
+ 2,
+ cv2.LINE_AA,
+ )
return obj, img
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "ReqRep", "/req_rep/my_message",
- carrier="tcp", persistent=True
+ "NativeObject",
+ "$mware",
+ "ReqRep",
+ "/req_rep/my_message",
+ carrier="tcp",
+ persistent=True,
)
@MiddlewareCommunicator.register(
- "AudioChunk", "$mware", "ReqRep", "/req_rep/my_audio_message",
- carrier="", rate="$aud_rate", chunk="$aud_chunk", channels="$aud_channels",
- persistent=True
+ "AudioChunk",
+ "$mware",
+ "ReqRep",
+ "/req_rep/my_audio_message",
+ carrier="",
+ rate="$aud_rate",
+ chunk="$aud_chunk",
+ channels="$aud_channels",
+ persistent=True,
)
- def send_aud_message(self, msg=None,
- aud_rate=-1, aud_chunk=-1, aud_channels=2,
- mware=None, *args, **kwargs):
- """Exchange messages with sounddevice audio chunks and other native Python objects."""
+ def send_aud_message(
+ self,
+ msg=None,
+ aud_rate=-1,
+ aud_chunk=-1,
+ aud_channels=2,
+ mware=None,
+ *args,
+ **kwargs,
+ ):
+ """
+ Exchange messages with sounddevice audio chunks and other native Python objects.
+ """
obj = {"message": msg, "args": args, "kwargs": kwargs}
# read audio from file
aud = sf.read("../../assets/sound_test.wav", dtype="float32")
@@ -102,32 +150,46 @@ def send_aud_message(self, msg=None,
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message requester and replier for native Python objects, images using OpenCV, and sound using PortAudio.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message requester and replier for native Python objects, images using OpenCV, and sound using PortAudio."
+ )
parser.add_argument(
- "--mode", type=str, default="request",
+ "--mode",
+ type=str,
+ default="request",
choices={"request", "reply"},
- help="The mode of communication, either 'request' or 'reply'"
+ help="The mode of communication, either 'request' or 'reply'",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
parser.add_argument(
- "--sound_device", type=int, default=0,
- help="The sound device to use for audio playback"
+ "--sound_device",
+ type=int,
+ default=0,
+ help="The sound device to use for audio playback",
)
parser.add_argument(
- "--list_sound_devices", action="store_true",
- help="List all available sound devices and exit"
+ "--list_sound_devices",
+ action="store_true",
+ help="List all available sound devices and exit",
)
parser.add_argument(
- "--stream", type=str, default="image", choices={"image", "audio"},
- help="The streamed data as either 'image' or 'audio'"
+ "--stream",
+ type=str,
+ default="image",
+ choices={"image", "audio"},
+ help="The streamed data as either 'image' or 'audio'",
)
return parser.parse_args()
@@ -145,7 +207,9 @@ def sound_play(my_aud, blocking=True, device=0):
sd.play(*my_aud, blocking=blocking, device=device)
return True
except sd.PortAudioError:
- logging.warning("PortAudioError: No device is found or the device is already in use. Will try again in 3 seconds.")
+ logging.warning(
+ "PortAudioError: No device is found or the device is already in use. Will try again in 3 seconds."
+ )
return False
@@ -154,10 +218,17 @@ def main(args):
print(sd.query_devices())
return
- """Main function to initiate ReqRep class and communication."""
+ """
+
+ Main function to initiate ReqRep class and communication.
+ """
if args.mode == "request" and args.mware == "zeromq":
- print("WE INTENTIONALLY WAIT 5 SECONDS TO ALLOW THE REPLIER ENOUGH TIME TO START UP. ")
- print("THIS IS NEEDED WHEN USING ZEROMQ AS THE COMMUNICATION MIDDLEWARE IF THE SERVER ")
+ print(
+ "WE INTENTIONALLY WAIT 5 SECONDS TO ALLOW THE REPLIER ENOUGH TIME TO START UP. "
+ )
+ print(
+ "THIS IS NEEDED WHEN USING ZEROMQ AS THE COMMUNICATION MIDDLEWARE IF THE SERVER "
+ )
print("IS SET TO SPAWN A PROXY BROKER (DEFAULT).")
time.sleep(5)
@@ -175,8 +246,13 @@ def main(args):
# but this separation is NOT necessary for the method to work
if args.mode == "request":
msg = input("Type your message: ")
- my_message, my_image = req_rep.send_img_message(msg, counter=counter, mware=args.mware)
- my_message2, my_aud, = req_rep.send_aud_message(msg, counter=counter, mware=args.mware)
+ my_message, my_image = req_rep.send_img_message(
+ msg, counter=counter, mware=args.mware
+ )
+ (
+ my_message2,
+ my_aud,
+ ) = req_rep.send_aud_message(msg, counter=counter, mware=args.mware)
my_message = my_message2 if my_message2 is not None else my_message
counter += 1
if my_message is not None:
@@ -186,7 +262,9 @@ def main(args):
cv2.imshow("Received image", my_image)
while True:
k = cv2.waitKey(1) & 0xFF
- if not (cv2.getWindowProperty("Received image", cv2.WND_PROP_VISIBLE)):
+ if not (
+ cv2.getWindowProperty("Received image", cv2.WND_PROP_VISIBLE)
+ ):
break
if cv2.waitKey(1) == 27:
@@ -204,7 +282,10 @@ def main(args):
# The send_message() only executes in "reply" mode,
# meaning, the method is only accessible from this code block
my_message, my_image = req_rep.send_img_message(mware=args.mware)
- my_message2, my_aud, = req_rep.send_aud_message(mware=args.mware)
+ (
+ my_message2,
+ my_aud,
+ ) = req_rep.send_aud_message(mware=args.mware)
my_message = my_message2 if my_message2 is not None else my_message
if my_message is not None:
print("Reply: received reply:", my_message)
@@ -229,4 +310,4 @@ def main(args):
if __name__ == "__main__":
args = parse_args()
- main(args)
\ No newline at end of file
+ main(args)
diff --git a/examples/communication_schemes/channeling_example.py b/examples/communication_schemes/channeling_example.py
index 8295161..ed8bf9e 100644
--- a/examples/communication_schemes/channeling_example.py
+++ b/examples/communication_schemes/channeling_example.py
@@ -42,55 +42,100 @@
class ChannelingCls(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware_A", "ChannelingCls", "/example/native_A_msg",
- carrier="mcast", should_wait=True
+ "NativeObject",
+ "$mware_A",
+ "ChannelingCls",
+ "/example/native_A_msg",
+ carrier="mcast",
+ should_wait=True,
)
@MiddlewareCommunicator.register(
- "Image", "$mware_B", "ChannelingCls", "/example/image_B_msg",
- carrier="tcp", width="$img_width", height="$img_height", rgb=True, should_wait=False
+ "Image",
+ "$mware_B",
+ "ChannelingCls",
+ "/example/image_B_msg",
+ carrier="tcp",
+ width="$img_width",
+ height="$img_height",
+ rgb=True,
+ should_wait=False,
)
@MiddlewareCommunicator.register(
- "AudioChunk", "$mware_C", "ChannelingCls", "/example/audio_C_msg",
- carrier="tcp", rate="$aud_rate", chunk="$aud_chunk", channels="$aud_channels", should_wait=False
+ "AudioChunk",
+ "$mware_C",
+ "ChannelingCls",
+ "/example/audio_C_msg",
+ carrier="tcp",
+ rate="$aud_rate",
+ chunk="$aud_chunk",
+ channels="$aud_channels",
+ should_wait=False,
)
- def read_mulret_mulmware(self, img_width=200, img_height=200,
- aud_rate=44100, aud_chunk=8820, aud_channels=1,
- mware_A=None, mware_B=None, mware_C=None):
- """Read and forward messages through channels A, B, and C."""
- ros_img = np.random.randint(256, size=(img_height, img_width, 3), dtype=np.uint8)
- zeromq_aud = (np.random.uniform(-1, 1, aud_chunk), aud_rate,)
+ def read_mulret_mulmware(
+ self,
+ img_width=200,
+ img_height=200,
+ aud_rate=44100,
+ aud_chunk=8820,
+ aud_channels=1,
+ mware_A=None,
+ mware_B=None,
+ mware_C=None,
+ ):
+ """
+ Read and forward messages through channels A, B, and C.
+ """
+ ros_img = np.random.randint(
+ 256, size=(img_height, img_width, 3), dtype=np.uint8
+ )
+ zeromq_aud = (
+ np.random.uniform(-1, 1, aud_chunk),
+ aud_rate,
+ )
yarp_native = [ros_img, zeromq_aud]
return yarp_native, ros_img, zeromq_aud
def parse_args():
- """Parse command line arguments."""
+ """
+ Parse command line arguments.
+ """
parser = argparse.ArgumentParser(description="Channeling Example using Wrapyfi.")
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware_A", type=str, default="none",
+ "--mware_A",
+ type=str,
+ default="none",
choices=MiddlewareCommunicator.get_communicators().update({"none"}),
- help="The middleware to use for transmission of channel A"
+ help="The middleware to use for transmission of channel A",
)
parser.add_argument(
- "--mware_B", type=str, default="none",
+ "--mware_B",
+ type=str,
+ default="none",
choices=MiddlewareCommunicator.get_communicators().update({"none"}),
- help="The middleware to use for transmission of channel B"
+ help="The middleware to use for transmission of channel B",
)
parser.add_argument(
- "--mware_C", type=str, default="none",
+ "--mware_C",
+ type=str,
+ default="none",
choices=MiddlewareCommunicator.get_communicators().update({"none"}),
- help="The middleware to use for transmission of channel C"
+ help="The middleware to use for transmission of channel C",
)
return parser.parse_args()
def main(args):
- """Main function to initiate ChannelingCls class and communication."""
+ """
+ Main function to initiate ChannelingCls class and communication.
+ """
channeling = ChannelingCls()
channeling.activate_communication(channeling.read_mulret_mulmware, mode=args.mode)
diff --git a/examples/communication_schemes/forwarding_example.py b/examples/communication_schemes/forwarding_example.py
index 0d8cebe..f49dd1e 100644
--- a/examples/communication_schemes/forwarding_example.py
+++ b/examples/communication_schemes/forwarding_example.py
@@ -36,62 +36,95 @@
class ForwardCls(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- 'NativeObject', '$mware_chain_A', 'ForwardCls', '/example/native_chain_A_msg',
- carrier='mcast', should_wait=True)
- def read_chain_A(self, mware_chain_A=None, msg=''):
- """Read and forward message from chain A."""
- return msg,
+ "NativeObject",
+ "$mware_chain_A",
+ "ForwardCls",
+ "/example/native_chain_A_msg",
+ carrier="mcast",
+ should_wait=True,
+ )
+ def read_chain_A(self, mware_chain_A=None, msg=""):
+ """
+ Read and forward message from chain A.
+ """
+ return (msg,)
@MiddlewareCommunicator.register(
- 'NativeObject', '$mware_chain_B', 'ForwardCls', '/example/native_chain_B_msg',
- carrier='tcp', should_wait=False)
- def read_chain_B(self, mware_chain_B=None, msg=''):
- """Read and forward message from chain B."""
- return msg,
+ "NativeObject",
+ "$mware_chain_B",
+ "ForwardCls",
+ "/example/native_chain_B_msg",
+ carrier="tcp",
+ should_wait=False,
+ )
+ def read_chain_B(self, mware_chain_B=None, msg=""):
+ """
+ Read and forward message from chain B.
+ """
+ return (msg,)
+
def parse_args():
- """Parse command line arguments."""
+ """
+ Parse command line arguments.
+ """
parser = argparse.ArgumentParser(description="Forwarding Example using Wrapyfi.")
parser.add_argument(
- "--mode_chain_A", type=str, default="publish",
+ "--mode_chain_A",
+ type=str,
+ default="publish",
choices=["listen", "publish", "disable", "none", None],
- help="The mode of transmission for the first method in the chain"
+ help="The mode of transmission for the first method in the chain",
)
parser.add_argument(
- "--mode_chain_B", type=str, default="listen",
+ "--mode_chain_B",
+ type=str,
+ default="listen",
choices=["listen", "publish", "disable", "none", None],
- help="The mode of transmission for the second method in the chain"
+ help="The mode of transmission for the second method in the chain",
)
parser.add_argument(
- "--mware_chain_A", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware_chain_A",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission of the first method in the chain"
+ help="The middleware to use for transmission of the first method in the chain",
)
parser.add_argument(
- "--mware_chain_B", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware_chain_B",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission of the second method in the chain"
+ help="The middleware to use for transmission of the second method in the chain",
)
return parser.parse_args()
def main(args):
- """Main function to initiate ForwardCls class and communication."""
+ """
+ Main function to initiate ForwardCls class and communication.
+ """
forward = ForwardCls()
forward.activate_communication(forward.read_chain_A, mode=args.mode_chain_A)
forward.activate_communication(forward.read_chain_B, mode=args.mode_chain_B)
while True:
- msg, = forward.read_chain_A(mware_chain_A=args.mware_chain_A,
- msg=f"This argument message was sent from read_chain_A transmitted over "
- f"{args.mware_chain_A}")
+ (msg,) = forward.read_chain_A(
+ mware_chain_A=args.mware_chain_A,
+ msg=f"This argument message was sent from read_chain_A transmitted over "
+ f"{args.mware_chain_A}",
+ )
if msg is not None:
print(msg)
- msg, = forward.read_chain_B(mware_chain_B=args.mware_chain_B,
- msg=f"{msg}. It was then forwarded to read_chain_B over {args.mware_chain_B}")
+ (msg,) = forward.read_chain_B(
+ mware_chain_B=args.mware_chain_B,
+ msg=f"{msg}. It was then forwarded to read_chain_B over {args.mware_chain_B}",
+ )
if msg is not None:
if args.mode_chain_B == "listen":
- print(f"{msg}. This message is the last in the chain received over {args.mware_chain_B}")
+ print(
+ f"{msg}. This message is the last in the chain received over {args.mware_chain_B}"
+ )
else:
print(msg)
time.sleep(0.1)
diff --git a/examples/communication_schemes/mirroring_example.py b/examples/communication_schemes/mirroring_example.py
index 8ba09c0..06bd089 100644
--- a/examples/communication_schemes/mirroring_example.py
+++ b/examples/communication_schemes/mirroring_example.py
@@ -41,37 +41,56 @@
class MirrorCls(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- 'NativeObject', '$mware', 'MirrorCls',
- '/example/read_msg',
- carrier='tcp', should_wait='$blocking')
- def read_msg(self, mware=None, msg='', blocking=True):
- """Exchange messages and mirror user input."""
- msg_ip = input('Type message: ')
- obj = {'msg': msg, 'msg_ip': msg_ip}
- return obj,
+ "NativeObject",
+ "$mware",
+ "MirrorCls",
+ "/example/read_msg",
+ carrier="tcp",
+ should_wait="$blocking",
+ )
+ def read_msg(self, mware=None, msg="", blocking=True):
+ """
+ Exchange messages and mirror user input.
+ """
+ msg_ip = input("Type message: ")
+ obj = {"msg": msg, "msg_ip": msg_ip}
+ return (obj,)
+
def parse_args():
- """Parse command line arguments."""
+ """
+ Parse command line arguments.
+ """
parser = argparse.ArgumentParser(description="Mirroring Example using Wrapyfi.")
parser.add_argument(
- "--mode", type=str, default="listen",
+ "--mode",
+ type=str,
+ default="listen",
choices={"publish", "listen", "request", "reply"},
- help="The communication mode (publish, listen, request, reply)"
+ help="The communication mode (publish, listen, request, reply)",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
+
def main(args):
- """Main function to initiate MirrorCls class and communication."""
+ """
+ Main function to initiate MirrorCls class and communication.
+ """
mirror = MirrorCls()
mirror.activate_communication(MirrorCls.read_msg, mode=args.mode)
while True:
- msg_object, = mirror.read_msg(mware=args.mware, msg=f"This argument message was sent by the {args.mode} script")
+ (msg_object,) = mirror.read_msg(
+ mware=args.mware,
+ msg=f"This argument message was sent by the {args.mode} script",
+ )
if msg_object is not None:
print(msg_object)
diff --git a/examples/custom_msgs/ros2_message_example.py b/examples/custom_msgs/ros2_message_example.py
index 0201720..49a3239 100644
--- a/examples/custom_msgs/ros2_message_example.py
+++ b/examples/custom_msgs/ros2_message_example.py
@@ -35,15 +35,23 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "ROS2Message", "ros2", "Notifier", "/notify/test_ros2_msg_str_exchange",
- should_wait=True
+ "ROS2Message",
+ "ros2",
+ "Notifier",
+ "/notify/test_ros2_msg_str_exchange",
+ should_wait=True,
)
@MiddlewareCommunicator.register(
- "ROS2Message", "ros2", "Notifier", "/notify/test_ros2_msg_pose_exchange",
- should_wait=True
+ "ROS2Message",
+ "ros2",
+ "Notifier",
+ "/notify/test_ros2_msg_pose_exchange",
+ should_wait=True,
)
def send_message(self):
- """Exchange ROS 2 messages over ROS 2."""
+ """
+ Exchange ROS 2 messages over ROS 2.
+ """
msg = input("Type your message: ")
quat = Quaternion()
quat.x = 0.1
@@ -57,18 +65,26 @@ def send_message(self):
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message publisher and listener for ROS 2 messages using Wrapyfi.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message publisher and listener for ROS 2 messages using Wrapyfi."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notify class and communication."""
+ """
+ Main function to initiate Notify class and communication.
+ """
ros2_message = Notifier()
ros2_message.activate_communication(Notifier.send_message, mode=args.mode)
diff --git a/examples/custom_msgs/ros_message_example.py b/examples/custom_msgs/ros_message_example.py
index 8c2b255..1cf1694 100755
--- a/examples/custom_msgs/ros_message_example.py
+++ b/examples/custom_msgs/ros_message_example.py
@@ -35,15 +35,23 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "ROSMessage", "ros", "Notifier", "/notify/test_ros_msg_str_exchange",
- should_wait=True
+ "ROSMessage",
+ "ros",
+ "Notifier",
+ "/notify/test_ros_msg_str_exchange",
+ should_wait=True,
)
@MiddlewareCommunicator.register(
- "ROSMessage", "ros", "Notifier", "/notify/test_ros_msg_pose_exchange",
- should_wait=True
+ "ROSMessage",
+ "ros",
+ "Notifier",
+ "/notify/test_ros_msg_pose_exchange",
+ should_wait=True,
)
def send_message(self):
- """Exchange ROS messages over ROS."""
+ """
+ Exchange ROS messages over ROS.
+ """
msg = input("Type your message: ")
quat = Quaternion()
quat.x = 0.1
@@ -57,18 +65,26 @@ def send_message(self):
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message publisher and listener for ROS messages using Wrapyfi.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message publisher and listener for ROS messages using Wrapyfi."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notify class and communication."""
+ """
+ Main function to initiate Notify class and communication.
+ """
ros_message = Notifier()
ros_message.activate_communication(Notifier.send_message, mode=args.mode)
diff --git a/examples/custom_msgs/ros_parameter_example.py b/examples/custom_msgs/ros_parameter_example.py
index 28c45c2..983e28d 100755
--- a/examples/custom_msgs/ros_parameter_example.py
+++ b/examples/custom_msgs/ros_parameter_example.py
@@ -32,44 +32,73 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "Properties", "ros", "Notifier", "/notify/test_property_exchange",
- should_wait=True
+ "Properties",
+ "ros",
+ "Notifier",
+ "/notify/test_property_exchange",
+ should_wait=True,
)
@MiddlewareCommunicator.register(
- "Properties", "ros", "Notifier", "/notify/test_property_exchange/a",
- should_wait=True
+ "Properties",
+ "ros",
+ "Notifier",
+ "/notify/test_property_exchange/a",
+ should_wait=True,
)
@MiddlewareCommunicator.register(
- "Properties", "ros", "Notifier", "/notify/test_property_exchange/e",
- persistent=False, should_wait=True
+ "Properties",
+ "ros",
+ "Notifier",
+ "/notify/test_property_exchange/e",
+ persistent=False,
+ should_wait=True,
)
def set_property(self):
- """Exchange ROS properties over ROS."""
+ """
+ Exchange ROS properties over ROS.
+ """
ret_str = input("Type your message: ")
- ret_multiprops = {"b": [1,2,3,4], "c": False, "d": 12.3}
- ret_non_persistent = "Non-persistent property which should be deleted on closure"
+ ret_multiprops = {"b": [1, 2, 3, 4], "c": False, "d": 12.3}
+ ret_non_persistent = (
+ "Non-persistent property which should be deleted on closure"
+ )
return ret_multiprops, ret_str, ret_non_persistent
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message publisher and listener for ROS properties using Wrapyfi.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message publisher and listener for ROS properties using Wrapyfi."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notify class and communication."""
+ """
+ Main function to initiate Notify class and communication.
+ """
ros_message = Notifier()
ros_message.activate_communication(Notifier.set_property, mode=args.mode)
while True:
- my_dict_message, my_string_message, my_nonpersistent_message = ros_message.set_property()
- print("Method result:", my_string_message, my_dict_message, my_nonpersistent_message)
+ my_dict_message, my_string_message, my_nonpersistent_message = (
+ ros_message.set_property()
+ )
+ print(
+ "Method result:",
+ my_string_message,
+ my_dict_message,
+ my_nonpersistent_message,
+ )
if __name__ == "__main__":
diff --git a/examples/encoders/astropy_example.py b/examples/encoders/astropy_example.py
index a010149..3dd0aff 100644
--- a/examples/encoders/astropy_example.py
+++ b/examples/encoders/astropy_example.py
@@ -41,57 +41,72 @@
# Modifying the WRAPYFI_PLUGINS_PATH environment variable to include the plugins directory
script_dir = os.path.dirname(os.path.realpath(__file__))
-if 'WRAPYFI_PLUGINS_PATH' in os.environ:
- os.environ['WRAPYFI_PLUGINS_PATH'] += os.pathsep + script_dir
+if "WRAPYFI_PLUGINS_PATH" in os.environ:
+ os.environ["WRAPYFI_PLUGINS_PATH"] += os.pathsep + script_dir
else:
- os.environ['WRAPYFI_PLUGINS_PATH'] = script_dir
+ os.environ["WRAPYFI_PLUGINS_PATH"] = script_dir
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_astropy_exchange",
- carrier="tcp", should_wait=True
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_astropy_exchange",
+ carrier="tcp",
+ should_wait=True,
)
def exchange_object(self, mware=None):
- """Exchange messages with Astropy Tables and other native Python objects."""
+ """
+ Exchange messages with Astropy Tables and other native Python objects.
+ """
msg = input("Type your message: ")
# Creating an example Astropy Table
t = Table()
- t['name'] = ['source 1', 'source 2', 'source 3']
- t['flux'] = [1.2, 2.2, 3.1]
+ t["name"] = ["source 1", "source 2", "source 3"]
+ t["flux"] = [1.2, 2.2, 3.1]
ret = {
"message": msg,
"astropy_table": t,
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
+ """
+ Parse command line arguments.
+ """
parser = argparse.ArgumentParser(
- description="A message publisher and listener for native Python objects and Astropy Tables.")
+ description="A message publisher and listener for native Python objects and Astropy Tables."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
diff --git a/examples/encoders/cupy_example.py b/examples/encoders/cupy_example.py
index 28af7fb..854b30c 100644
--- a/examples/encoders/cupy_example.py
+++ b/examples/encoders/cupy_example.py
@@ -39,44 +39,60 @@
class CuPyNotifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "CuPyNotifier", "/notify/test_cupy_exchange",
- carrier="", should_wait=True,
- listener_kwargs=dict(load_cupy_device=cp.cuda.Device(0))
+ "NativeObject",
+ "$mware",
+ "CuPyNotifier",
+ "/notify/test_cupy_exchange",
+ carrier="",
+ should_wait=True,
+ listener_kwargs=dict(load_cupy_device=cp.cuda.Device(0)),
)
def exchange_object(self, mware=None):
- """Exchange messages with CuPy tensors and other native Python objects."""
+ """
+ Exchange messages with CuPy tensors and other native Python objects.
+ """
msg = input("Type your message: ")
ret = {
"message": msg,
"cupy_ones_cuda": cp.ones((2, 4), dtype=cp.float32),
- "cupy_zeros_cuda": cp.zeros((2, 3), dtype=cp.float32)
+ "cupy_zeros_cuda": cp.zeros((2, 3), dtype=cp.float32),
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and CuPy tensors.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message publisher and listener for native Python objects and CuPy tensors."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate CuPyNotifier class and communication."""
+ """
+ Main function to initiate CuPyNotifier class and communication.
+ """
notifier = CuPyNotifier()
notifier.activate_communication(CuPyNotifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
diff --git a/examples/encoders/dask_example.py b/examples/encoders/dask_example.py
index f914156..dbcc706 100644
--- a/examples/encoders/dask_example.py
+++ b/examples/encoders/dask_example.py
@@ -40,19 +40,28 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_dask_exchange",
- carrier="tcp", should_wait=True
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_dask_exchange",
+ carrier="tcp",
+ should_wait=True,
)
def exchange_object(self, mware=None):
- """Exchange messages with Dask arrays/dataframes and other native Python objects."""
+ """
+ Exchange messages with Dask arrays/dataframes and other native Python objects.
+ """
msg = input("Type your message: ")
# Creating an example Dask DataFrame
- df = pd.DataFrame({
- 'num_legs': [4, 2, 0, 4],
- 'num_wings': [0, 2, 0, 0],
- 'num_specimen_seen': [10, 2, 1, 8]
- }, index=['falcon', 'parrot', 'fish', 'dog'])
+ df = pd.DataFrame(
+ {
+ "num_legs": [4, 2, 0, 4],
+ "num_wings": [0, 2, 0, 0],
+ "num_specimen_seen": [10, 2, 1, 8],
+ },
+ index=["falcon", "parrot", "fish", "dog"],
+ )
ddf = dd.from_pandas(df, npartitions=2)
@@ -67,32 +76,42 @@ def exchange_object(self, mware=None):
"dask_array": darray,
"dask_series": dds,
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and Dask arrays/dataframes.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message publisher and listener for native Python objects and Dask arrays/dataframes."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
# Compute and print the actual values of the Dask objects
for key, value in msg_object.items():
diff --git a/examples/encoders/jax_example.py b/examples/encoders/jax_example.py
index 8778e56..b5022f8 100755
--- a/examples/encoders/jax_example.py
+++ b/examples/encoders/jax_example.py
@@ -38,42 +38,58 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_jax_exchange",
- carrier="tcp", should_wait=True
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_jax_exchange",
+ carrier="tcp",
+ should_wait=True,
)
def exchange_object(self, mware=None):
- """Exchange messages with JAX arrays and other native Python objects."""
+ """
+ Exchange messages with JAX arrays and other native Python objects.
+ """
msg = input("Type your message: ")
ret = {
"message": msg,
"jax_ones": jnp.ones((2, 4)),
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and JAX arrays.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message publisher and listener for native Python objects and JAX arrays."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
diff --git a/examples/encoders/mxnet_example.py b/examples/encoders/mxnet_example.py
index 1ff24a2..d34c9dc 100755
--- a/examples/encoders/mxnet_example.py
+++ b/examples/encoders/mxnet_example.py
@@ -40,44 +40,63 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_mxnet_exchange",
- carrier="", should_wait=True,
- listener_kwargs=dict(load_mxnet_device=mxnet.gpu(0), map_mxnet_devices={'cpu': 'cuda:0', 'gpu:0': 'cpu'})
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_mxnet_exchange",
+ carrier="",
+ should_wait=True,
+ listener_kwargs=dict(
+ load_mxnet_device=mxnet.gpu(0),
+ map_mxnet_devices={"cpu": "cuda:0", "gpu:0": "cpu"},
+ ),
)
def exchange_object(self, mware=None):
- """Exchange messages with MXNet tensors and other native Python objects."""
+ """
+ Exchange messages with MXNet tensors and other native Python objects.
+ """
msg = input("Type your message: ")
ret = {
"message": msg,
"mx_ones": mxnet.nd.ones((2, 4), ctx=mxnet.cpu()),
- "mxnet_zeros_cuda": mxnet.nd.zeros((2, 3), ctx=mxnet.gpu(0))
+ "mxnet_zeros_cuda": mxnet.nd.zeros((2, 3), ctx=mxnet.gpu(0)),
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and MXNet tensors.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message publisher and listener for native Python objects and MXNet tensors."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
diff --git a/examples/encoders/numpy_pandas_example.py b/examples/encoders/numpy_pandas_example.py
index f069280..737fa20 100755
--- a/examples/encoders/numpy_pandas_example.py
+++ b/examples/encoders/numpy_pandas_example.py
@@ -40,46 +40,65 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_native_exchange",
- carrier="tcp", should_wait=True
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_native_exchange",
+ carrier="tcp",
+ should_wait=True,
)
def exchange_object(self, mware=None):
- """Exchange messages with NumPy arrays, pandas series/dataframes, and other native Python objects."""
+ """
+ Exchange messages with NumPy arrays, pandas series/dataframes, and other native Python objects.
+ """
msg = input("Type your message: ")
ret = {
"message": msg,
"numpy_array": np.ones((2, 4)),
"pandas_series": pd.Series([1, 3, 5, np.nan, 6, 8]),
- "pandas_dataframe": pd.DataFrame(np.random.randn(6, 4), index=pd.date_range("20130101", periods=6),
- columns=list("ABCD")),
+ "pandas_dataframe": pd.DataFrame(
+ np.random.randn(6, 4),
+ index=pd.date_range("20130101", periods=6),
+ columns=list("ABCD"),
+ ),
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects, NumPy arrays, and pandas series/dataframes.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message publisher and listener for native Python objects, NumPy arrays, and pandas series/dataframes."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
diff --git a/examples/encoders/paddlepaddle_example.py b/examples/encoders/paddlepaddle_example.py
index bce99ef..b5e71e3 100755
--- a/examples/encoders/paddlepaddle_example.py
+++ b/examples/encoders/paddlepaddle_example.py
@@ -29,6 +29,7 @@
"""
import argparse
+
try:
import paddle
except ImportError:
@@ -39,44 +40,67 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_paddle_exchange",
- carrier="", should_wait=True,
- listener_kwargs=dict(load_paddle_device='gpu:0', map_paddle_devices={'cpu': 'cuda:0', 'gpu:0': 'cpu'})
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_paddle_exchange",
+ carrier="",
+ should_wait=True,
+ listener_kwargs=dict(
+ load_paddle_device="gpu:0",
+ map_paddle_devices={"cpu": "cuda:0", "gpu:0": "cpu"},
+ ),
)
def exchange_object(self, mware=None):
- """Exchange messages with PaddlePaddle tensors and other native Python objects."""
+ """
+ Exchange messages with PaddlePaddle tensors and other native Python objects.
+ """
msg = input("Type your message: ")
ret = {
"message": msg,
- "paddle_ones": paddle.ones([2, 4], dtype='float32', place=paddle.CPUPlace()),
- "paddle_zeros_cuda": paddle.zeros([2, 3], dtype='float32', place=paddle.CUDAPlace(0))
+ "paddle_ones": paddle.ones(
+ [2, 4], dtype="float32", place=paddle.CPUPlace()
+ ),
+ "paddle_zeros_cuda": paddle.zeros(
+ [2, 3], dtype="float32", place=paddle.CUDAPlace(0)
+ ),
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and PaddlePaddle tensors.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message publisher and listener for native Python objects and PaddlePaddle tensors."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
diff --git a/examples/encoders/pillow_example.py b/examples/encoders/pillow_example.py
index 7fa29f5..74edd05 100755
--- a/examples/encoders/pillow_example.py
+++ b/examples/encoders/pillow_example.py
@@ -40,44 +40,56 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notify", "/notify/test_native_exchange",
- carrier="", should_wait=True
+ "NativeObject",
+ "$mware",
+ "Notify",
+ "/notify/test_native_exchange",
+ carrier="",
+ should_wait=True,
)
def exchange_object(self, mware=None):
msg = input("Type your message: ")
imarray = np.random.rand(100, 100, 3) * 255
ret = {
"message": msg,
- "pillow_random": Image.fromarray(imarray.astype('uint8')).convert('RGBA'),
+ "pillow_random": Image.fromarray(imarray.astype("uint8")).convert("RGBA"),
"pillow_png": Image.open("../../assets/wrapyfi.png"),
- "pillow_jpg": Image.open("../../assets/wrapyfi.jpg")
+ "pillow_jpg": Image.open("../../assets/wrapyfi.jpg"),
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
+ """
+ Parse command line arguments.
+ """
parser = argparse.ArgumentParser()
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notify class and communication."""
+ """
+ Main function to initiate Notify class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
diff --git a/examples/encoders/pint_example.py b/examples/encoders/pint_example.py
index 040b068..81a2d45 100644
--- a/examples/encoders/pint_example.py
+++ b/examples/encoders/pint_example.py
@@ -38,53 +38,76 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_pint_exchange",
- carrier="tcp", should_wait=True
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_pint_exchange",
+ carrier="tcp",
+ should_wait=True,
)
def exchange_object(self, mware=None):
- """Exchange messages with Pint Quantities and other native Python objects."""
+ """
+ Exchange messages with Pint Quantities and other native Python objects.
+ """
msg = input("Type your message: ")
# Creating a Pint Quantity
ureg = pint.UnitRegistry()
- quantity = 42 * ureg.parse_expression('meter')
+ quantity = 42 * ureg.parse_expression("meter")
# Constructing the message object to be transmitted
- ret = [{"message": msg,
- "pint_quantity": quantity,
- "list": [1, 2, 3]},
- "string",
- 0.4344,
- {"other": (1, 2, 3, 4.32,)}]
- return ret,
+ ret = [
+ {"message": msg, "pint_quantity": quantity, "list": [1, 2, 3]},
+ "string",
+ 0.4344,
+ {
+ "other": (
+ 1,
+ 2,
+ 3,
+ 4.32,
+ )
+ },
+ ]
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A message publisher and listener for native Python objects and Pint Quantities.")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A message publisher and listener for native Python objects and Pint Quantities."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
if __name__ == "__main__":
args = parse_args()
- main(args)
\ No newline at end of file
+ main(args)
diff --git a/examples/encoders/plugins/astropy_tables.py b/examples/encoders/plugins/astropy_tables.py
index f3fe48f..375c1b8 100644
--- a/examples/encoders/plugins/astropy_tables.py
+++ b/examples/encoders/plugins/astropy_tables.py
@@ -24,6 +24,7 @@
try:
from astropy.table import Table
+
HAVE_ASTROPY = True
except ImportError:
HAVE_ASTROPY = False
@@ -54,9 +55,9 @@ def encode(self, obj, *args, **kwargs):
- '__wrapyfi__': A tuple containing the class name and the encoded data string
"""
memfile = io.BytesIO()
- obj.write(memfile, format='fits')
+ obj.write(memfile, format="fits")
memfile.seek(0)
- obj_data = base64.b64encode(memfile.getvalue()).decode('ascii')
+ obj_data = base64.b64encode(memfile.getvalue()).decode("ascii")
memfile.close()
return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data))
@@ -74,8 +75,8 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
"""
encoded_str = obj_full[1]
if isinstance(encoded_str, str):
- encoded_str = encoded_str.encode('ascii')
+ encoded_str = encoded_str.encode("ascii")
with io.BytesIO(base64.b64decode(encoded_str)) as memfile:
memfile.seek(0)
- obj = Table.read(memfile, format='fits')
+ obj = Table.read(memfile, format="fits")
return True, obj
diff --git a/examples/encoders/pyarrow_example.py b/examples/encoders/pyarrow_example.py
index ec87244..9027575 100755
--- a/examples/encoders/pyarrow_example.py
+++ b/examples/encoders/pyarrow_example.py
@@ -27,6 +27,7 @@
"""
import argparse
+
try:
import pyarrow as pa
except ImportError:
@@ -37,46 +38,61 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_arrow_exchange",
- carrier="tcp", should_wait=True
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_arrow_exchange",
+ carrier="tcp",
+ should_wait=True,
)
def exchange_object(self, mware=None):
- """Exchange messages with PyArrow arrays and other native Python objects."""
+ """
+ Exchange messages with PyArrow arrays and other native Python objects.
+ """
msg = input("Type your message: ")
ret = {
"message": msg,
"pyarrow_array": pa.array(range(100)),
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
+ """
+ Parse command line arguments.
+ """
parser = argparse.ArgumentParser(
- description="A message publisher and listener for native Python objects and PyArrow arrays.")
+ description="A message publisher and listener for native Python objects and PyArrow arrays."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
if __name__ == "__main__":
args = parse_args()
- main(args)
\ No newline at end of file
+ main(args)
diff --git a/examples/encoders/pytorch_example.py b/examples/encoders/pytorch_example.py
index 37b4351..1da2e91 100755
--- a/examples/encoders/pytorch_example.py
+++ b/examples/encoders/pytorch_example.py
@@ -40,45 +40,63 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_torch_exchange",
- carrier="", should_wait=True,
- listener_kwargs=dict(load_torch_device='cuda:0', map_torch_devices={'cpu': 'cuda:0', 'cuda:0': 'cpu'})
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_torch_exchange",
+ carrier="",
+ should_wait=True,
+ listener_kwargs=dict(
+ load_torch_device="cuda:0",
+ map_torch_devices={"cpu": "cuda:0", "cuda:0": "cpu"},
+ ),
)
def exchange_object(self, mware=None):
- """Exchange messages with PyTorch tensors and other native Python objects."""
+ """
+ Exchange messages with PyTorch tensors and other native Python objects.
+ """
msg = input("Type your message: ")
ret = {
"message": msg,
- "torch_ones": torch.ones((2, 4), device='cpu'),
- "torch_zeros_cuda": torch.zeros((2, 3), device='cuda:0')
+ "torch_ones": torch.ones((2, 4), device="cpu"),
+ "torch_zeros_cuda": torch.zeros((2, 3), device="cuda:0"),
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
+ """
+ Parse command line arguments.
+ """
parser = argparse.ArgumentParser(
- description="A message publisher and listener for native Python objects and PyTorch tensors.")
+ description="A message publisher and listener for native Python objects and PyTorch tensors."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
diff --git a/examples/encoders/tensorflow_example.py b/examples/encoders/tensorflow_example.py
index fe3290f..c17bcd8 100755
--- a/examples/encoders/tensorflow_example.py
+++ b/examples/encoders/tensorflow_example.py
@@ -27,6 +27,7 @@
"""
import argparse
+
try:
import tensorflow as tf
except ImportError:
@@ -37,45 +38,60 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_native_exchange",
- carrier="", should_wait=True
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_native_exchange",
+ carrier="",
+ should_wait=True,
)
def exchange_object(self, mware=None):
- """Exchange messages with TensorFlow tensors and other native Python objects."""
+ """
+ Exchange messages with TensorFlow tensors and other native Python objects.
+ """
msg = input("Type your message: ")
ret = {
"message": msg,
"tf_ones": tf.ones((2, 4)),
- "tf_string": tf.constant("This is string")
+ "tf_string": tf.constant("This is string"),
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
+ """
+ Parse command line arguments.
+ """
parser = argparse.ArgumentParser(
- description="A message publisher and listener for native Python objects and TensorFlow tensors.")
+ description="A message publisher and listener for native Python objects and TensorFlow tensors."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
diff --git a/examples/encoders/xarray_example.py b/examples/encoders/xarray_example.py
index 51f9b89..eeb3365 100644
--- a/examples/encoders/xarray_example.py
+++ b/examples/encoders/xarray_example.py
@@ -41,57 +41,74 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_xarray_exchange",
- carrier="tcp", should_wait=True
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_xarray_exchange",
+ carrier="tcp",
+ should_wait=True,
)
def exchange_object(self, mware=None):
- """Exchange messages with xarray DataArrays and other native Python objects."""
+ """
+ Exchange messages with xarray DataArrays and other native Python objects.
+ """
msg = input("Type your message: ")
# Creating an example xarray DataArray
data = np.random.rand(4, 3)
- locs = ['IA', 'IL', 'IN']
- times = pd.date_range('2000-01-01', periods=4)
- da = xr.DataArray(data, coords=[times, locs], dims=['time', 'space'], name='example')
+ locs = ["IA", "IL", "IN"]
+ times = pd.date_range("2000-01-01", periods=4)
+ da = xr.DataArray(
+ data, coords=[times, locs], dims=["time", "space"], name="example"
+ )
ret = {
"message": msg,
"xarray_dataarray": da,
"additional_info": {
- "set": {'a', 1, None},
- "list": [[[3, [4], 5.677890, 1.2]]]
- }
+ "set": {"a", 1, None},
+ "list": [[[3, [4], 5.677890, 1.2]]],
+ },
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
+ """
+ Parse command line arguments.
+ """
parser = argparse.ArgumentParser(
- description="A message publisher and listener for native Python objects and xarray DataArrays.")
+ description="A message publisher and listener for native Python objects and xarray DataArrays."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
if __name__ == "__main__":
args = parse_args()
- main(args)
\ No newline at end of file
+ main(args)
diff --git a/examples/encoders/zarr_example.py b/examples/encoders/zarr_example.py
index 983cc66..4a60d2d 100644
--- a/examples/encoders/zarr_example.py
+++ b/examples/encoders/zarr_example.py
@@ -40,11 +40,17 @@
class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
- "NativeObject", "$mware", "Notifier", "/notify/test_zarr_exchange",
- carrier="tcp", should_wait=True
+ "NativeObject",
+ "$mware",
+ "Notifier",
+ "/notify/test_zarr_exchange",
+ carrier="tcp",
+ should_wait=True,
)
def exchange_object(self, mware=None):
- """Exchange messages with Zarr arrays/groups and other native Python objects."""
+ """
+ Exchange messages with Zarr arrays/groups and other native Python objects.
+ """
msg = input("Type your message: ")
# Creating an example Zarr Array
@@ -52,41 +58,50 @@ def exchange_object(self, mware=None):
# Creating an example Zarr Group
zgroup = zarr.group()
- zgroup.create_dataset('dataset1', data=np.random.randint(0, 100, 50), chunks=10)
- zgroup.create_dataset('dataset2', data=np.random.random(100), chunks=10)
+ zgroup.create_dataset("dataset1", data=np.random.randint(0, 100, 50), chunks=10)
+ zgroup.create_dataset("dataset2", data=np.random.random(100), chunks=10)
ret = {
"message": msg,
"zarr_array": zarray,
"zarr_group": zgroup,
}
- return ret,
+ return (ret,)
def parse_args():
- """Parse command line arguments."""
+ """
+ Parse command line arguments.
+ """
parser = argparse.ArgumentParser(
- description="A message publisher and listener for native Python objects and Zarr arrays/groups.")
+ description="A message publisher and listener for native Python objects and Zarr arrays/groups."
+ )
parser.add_argument(
- "--mode", type=str, default="publish",
+ "--mode",
+ type=str,
+ default="publish",
choices={"publish", "listen"},
- help="The transmission mode"
+ help="The transmission mode",
)
parser.add_argument(
- "--mware", type=str, default=DEFAULT_COMMUNICATOR,
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission"
+ help="The middleware to use for transmission",
)
return parser.parse_args()
def main(args):
- """Main function to initiate Notifier class and communication."""
+ """
+ Main function to initiate Notifier class and communication.
+ """
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)
while True:
- msg_object, = notifier.exchange_object(mware=args.mware)
+ (msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)
diff --git a/examples/hello_world.py b/examples/hello_world.py
index 5496e3d..6edcc7b 100755
--- a/examples/hello_world.py
+++ b/examples/hello_world.py
@@ -29,6 +29,7 @@
"""
+
import argparse
from wrapyfi.connect.wrapper import MiddlewareCommunicator, DEFAULT_COMMUNICATOR
@@ -36,22 +37,64 @@
class HelloWorld(MiddlewareCommunicator):
- @MiddlewareCommunicator.register("NativeObject", "$mware", "HelloWorld", "/hello/my_message",
- carrier="tcp", should_wait=True)
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$mware",
+ "HelloWorld",
+ "/hello/my_message",
+ carrier="tcp",
+ should_wait=True,
+ )
def send_message(self, arg_from_requester="", mware=None):
- """Exchange messages and mirror user input."""
+ """
+ Exchange messages and mirror user input.
+ """
msg = input("Type your message: ")
obj = {"message": msg, "message_from_requester": arg_from_requester}
- return obj,
+ return (obj,)
+
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("--publish", dest="mode", action="store_const", const="publish", default="listen", help="Publish mode")
- parser.add_argument("--listen", dest="mode", action="store_const", const="listen", default="listen", help="Listen mode (default)")
- parser.add_argument("--request", dest="mode", action="store_const", const="request", default="listen", help="Request mode")
- parser.add_argument("--reply", dest="mode", action="store_const", const="reply", default="listen", help="Reply mode")
- parser.add_argument("--mware", type=str, default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission")
+ parser.add_argument(
+ "--publish",
+ dest="mode",
+ action="store_const",
+ const="publish",
+ default="listen",
+ help="Publish mode",
+ )
+ parser.add_argument(
+ "--listen",
+ dest="mode",
+ action="store_const",
+ const="listen",
+ default="listen",
+ help="Listen mode (default)",
+ )
+ parser.add_argument(
+ "--request",
+ dest="mode",
+ action="store_const",
+ const="request",
+ default="listen",
+ help="Request mode",
+ )
+ parser.add_argument(
+ "--reply",
+ dest="mode",
+ action="store_const",
+ const="reply",
+ default="listen",
+ help="Reply mode",
+ )
+ parser.add_argument(
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
+ choices=MiddlewareCommunicator.get_communicators(),
+ help="The middleware to use for transmission",
+ )
return parser.parse_args()
@@ -61,6 +104,9 @@ def parse_args():
hello_world.activate_communication(HelloWorld.send_message, mode=args.mode)
while True:
- my_message, = hello_world.send_message(arg_from_requester=f"I got this message from the script running in {args.mode} mode", mware=args.mware)
+ (my_message,) = hello_world.send_message(
+ arg_from_requester=f"I got this message from the script running in {args.mode} mode",
+ mware=args.mware,
+ )
if my_message is not None:
print("Method result:", my_message)
diff --git a/examples/robots/icub_head.py b/examples/robots/icub_head.py
index c330bd9..d61cbb2 100755
--- a/examples/robots/icub_head.py
+++ b/examples/robots/icub_head.py
@@ -86,9 +86,15 @@
from wrapyfi.connect.wrapper import MiddlewareCommunicator
ICUB_DEFAULT_COMMUNICATOR = os.environ.get("WRAPYFI_DEFAULT_COMMUNICATOR", "yarp")
-ICUB_DEFAULT_COMMUNICATOR = os.environ.get("WRAPYFI_DEFAULT_MWARE", ICUB_DEFAULT_COMMUNICATOR)
-ICUB_DEFAULT_COMMUNICATOR = os.environ.get("ICUB_DEFAULT_COMMUNICATOR", ICUB_DEFAULT_COMMUNICATOR)
-ICUB_DEFAULT_COMMUNICATOR = os.environ.get("ICUB_DEFAULT_MWARE", ICUB_DEFAULT_COMMUNICATOR)
+ICUB_DEFAULT_COMMUNICATOR = os.environ.get(
+ "WRAPYFI_DEFAULT_MWARE", ICUB_DEFAULT_COMMUNICATOR
+)
+ICUB_DEFAULT_COMMUNICATOR = os.environ.get(
+ "ICUB_DEFAULT_COMMUNICATOR", ICUB_DEFAULT_COMMUNICATOR
+)
+ICUB_DEFAULT_COMMUNICATOR = os.environ.get(
+ "ICUB_DEFAULT_MWARE", ICUB_DEFAULT_COMMUNICATOR
+)
EMOTION_LOOKUP = {
"Neutral": [("LIGHTS", "neu")],
@@ -101,7 +107,7 @@
"Contempt": [("raw", "L01"), ("raw", "R09"), ("raw", "ME9")], # change to array
"Cunning": [("LIGHTS", "cun")],
"Shy": [("LIGHTS", "shy")],
- "Evil": [("LIGHTS", "evi")]
+ "Evil": [("LIGHTS", "evi")],
}
@@ -129,7 +135,9 @@ def cartesian_to_spherical(xyz=None, x=None, y=None, z=None, expand_return=None)
ptr = np.zeros((3,))
xy = xyz[0] ** 2 + xyz[1] ** 2
ptr[0] = np.arctan2(xyz[1], xyz[0])
- ptr[1] = np.arctan2(xyz[2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
+ ptr[1] = np.arctan2(
+ xyz[2], np.sqrt(xy)
+ ) # for elevation angle defined from XY-plane up
# ptr[1] = np.arctan2(np.sqrt(xy), xyz[2]) # for elevation angle defined from Z-axis down
ptr[2] = np.sqrt(xy + xyz[2] ** 2)
return ptr if not expand_return else {"p": ptr[0], "t": ptr[1], "r": ptr[2]}
@@ -145,6 +153,7 @@ def mode_smoothing_filter(time_series, default, window_length=6, min_count=None)
:param min_count: int: Minimum number of values in the window to apply the smoothing filter
"""
import scipy.stats
+
if min_count is None:
min_count = window_length // 2
mode = scipy.stats.mode(time_series[-window_length:])
@@ -171,15 +180,26 @@ class ICub(MiddlewareCommunicator, yarp.RFModule):
FACIAL_EXPRESSIONS_QUEUE_SIZE = 7
FACIAL_EXPRESSION_SMOOTHING_WINDOW = 6
- def __init__(self, simulation=False, headless=False, get_cam_feed=True,
- img_width=CAP_PROP_FRAME_WIDTH, img_height=CAP_PROP_FRAME_HEIGHT,
- control_head=True, set_head_coordinates=True, head_coordinates_port=HEAD_COORDINATES_PORT,
- control_eyes=True, set_eye_coordinates=True, eye_coordinates_port=EYE_COORDINATES_PORT,
- ikingaze=False,
- gaze_plane_coordinates_port=GAZE_PLANE_COORDINATES_PORT,
- control_expressions=False,
- set_facial_expressions=True, facial_expressions_port=FACIAL_EXPRESSIONS_PORT,
- mware=MWARE):
+ def __init__(
+ self,
+ simulation=False,
+ headless=False,
+ get_cam_feed=True,
+ img_width=CAP_PROP_FRAME_WIDTH,
+ img_height=CAP_PROP_FRAME_HEIGHT,
+ control_head=True,
+ set_head_coordinates=True,
+ head_coordinates_port=HEAD_COORDINATES_PORT,
+ control_eyes=True,
+ set_eye_coordinates=True,
+ eye_coordinates_port=EYE_COORDINATES_PORT,
+ ikingaze=False,
+ gaze_plane_coordinates_port=GAZE_PLANE_COORDINATES_PORT,
+ control_expressions=False,
+ set_facial_expressions=True,
+ facial_expressions_port=FACIAL_EXPRESSIONS_PORT,
+ mware=MWARE,
+ ):
"""
Initialize the ICub head controller, facial expression transmitter and camera viewer.
@@ -222,15 +242,19 @@ def __init__(self, simulation=False, headless=False, get_cam_feed=True,
if simulation:
props.put("remote", "/icubSim/head")
- self.cam_props = {"cam_world_port": "/icubSim/cam",
- "cam_left_port": "/icubSim/cam/left",
- "cam_right_port": "/icubSim/cam/right"}
+ self.cam_props = {
+ "cam_world_port": "/icubSim/cam",
+ "cam_left_port": "/icubSim/cam/left",
+ "cam_right_port": "/icubSim/cam/right",
+ }
emotion_cmd = f"yarp rpc /icubSim/face/emotions/in"
else:
props.put("remote", "/icub/head")
- self.cam_props = {"cam_world_port": "/icub/cam/left",
- "cam_left_port": "/icub/cam/left",
- "cam_right_port": "/icub/cam/right"}
+ self.cam_props = {
+ "cam_world_port": "/icub/cam/left",
+ "cam_left_port": "/icub/cam/left",
+ "cam_right_port": "/icub/cam/right",
+ }
emotion_cmd = f"yarp rpc /icub/face/emotions/in"
if img_width is not None:
@@ -248,10 +272,15 @@ def __init__(self, simulation=False, headless=False, get_cam_feed=True,
# control emotional expressions using RPC
self.client = pexpect.spawn(emotion_cmd)
else:
- logging.error("pexpect must be installed to control the emotion interface")
+ logging.error(
+ "pexpect must be installed to control the emotion interface"
+ )
self.activate_communication(self.update_facial_expressions, "disable")
- self.last_expression = ["", ""] # (emotion part on the robot's face , emotional expression category)
+ self.last_expression = [
+ "",
+ "",
+ ] # (emotion part on the robot's face , emotional expression category)
self.expressions_queue = deque(maxlen=self.FACIAL_EXPRESSIONS_QUEUE_SIZE)
else:
self.activate_communication(self.update_facial_expressions, "disable")
@@ -352,9 +381,20 @@ def build(self):
Updates the default method arguments according to constructor arguments. This method is called by the module constructor.
It is not necessary to call it manually.
"""
- ICub.acquire_head_coordinates.__defaults__ = (self.HEAD_COORDINATES_PORT, None, self.MWARE)
- ICub.acquire_eye_coordinates.__defaults__ = (self.EYE_COORDINATES_PORT, None, self.MWARE)
- ICub.receive_gaze_plane_coordinates.__defaults__ = (self.GAZE_PLANE_COORDINATES_PORT, self.MWARE)
+ ICub.acquire_head_coordinates.__defaults__ = (
+ self.HEAD_COORDINATES_PORT,
+ None,
+ self.MWARE,
+ )
+ ICub.acquire_eye_coordinates.__defaults__ = (
+ self.EYE_COORDINATES_PORT,
+ None,
+ self.MWARE,
+ )
+ ICub.receive_gaze_plane_coordinates.__defaults__ = (
+ self.GAZE_PLANE_COORDINATES_PORT,
+ self.MWARE,
+ )
ICub.wait_for_gaze.__defaults__ = (True, self.MWARE)
ICub.reset_gaze.__defaults__ = (self.MWARE,)
ICub.update_head_gaze_speed.__defaults__ = (10.0, 10.0, 20.0, 0.8, self.MWARE)
@@ -362,16 +402,37 @@ def build(self):
ICub.update_eye_gaze_speed.__defaults__ = (10.0, 10.0, 20.0, 0.5, self.MWARE)
ICub.control_eye_gaze.__defaults__ = (0.0, 0.0, 0.0, self.MWARE)
ICub._control_head_eye_gaze.__defaults__ = (self.MWARE,)
- ICub.control_gaze_at_plane.__defaults__ = (0, 0, 0.3, 0.3, True, True, self.MWARE)
- ICub.acquire_facial_expressions.__defaults__ = (self.FACIAL_EXPRESSIONS_PORT, None, self.MWARE)
+ ICub.control_gaze_at_plane.__defaults__ = (
+ 0,
+ 0,
+ 0.3,
+ 0.3,
+ True,
+ True,
+ self.MWARE,
+ )
+ ICub.acquire_facial_expressions.__defaults__ = (
+ self.FACIAL_EXPRESSIONS_PORT,
+ None,
+ self.MWARE,
+ )
ICub.update_facial_expressions.__defaults__ = (None, False, "mode", self.MWARE)
- ICub.receive_images.__defaults__ = (self.CAP_PROP_FRAME_WIDTH, self.CAP_PROP_FRAME_HEIGHT, True)
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "$head_coordinates_port",
- should_wait=False)
- def acquire_head_coordinates(self, head_coordinates_port=HEAD_COORDINATES_PORT, cv2_key=None,
- _mware=MWARE, **kwargs):
+ ICub.receive_images.__defaults__ = (
+ self.CAP_PROP_FRAME_WIDTH,
+ self.CAP_PROP_FRAME_HEIGHT,
+ True,
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject", "$_mware", "ICub", "$head_coordinates_port", should_wait=False
+ )
+ def acquire_head_coordinates(
+ self,
+ head_coordinates_port=HEAD_COORDINATES_PORT,
+ cv2_key=None,
+ _mware=MWARE,
+ **kwargs,
+ ):
"""
Acquire head coordinates for controlling the iCub.
@@ -382,7 +443,7 @@ def acquire_head_coordinates(self, head_coordinates_port=HEAD_COORDINATES_PORT,
if cv2_key is None:
logging.error("controlling orientation in headless mode not yet supported")
- return None,
+ return (None,)
else:
if cv2_key == 27: # Esc key to exit
exit(0)
@@ -413,20 +474,29 @@ def acquire_head_coordinates(self, head_coordinates_port=HEAD_COORDINATES_PORT,
logging.info("resetting the orientation")
else:
logging.info(cv2_key) # else print its value
- return None,
+ return (None,)
- return {"topic": head_coordinates_port.split("/")[-1],
+ return (
+ {
+ "topic": head_coordinates_port.split("/")[-1],
"timestamp": time.time(),
"pitch": self._curr_head[0],
"roll": self._curr_head[1],
"yaw": self._curr_head[2],
- "order": "zyx"},
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "$eye_coordinates_port",
- should_wait=False)
- def acquire_eye_coordinates(self, eye_coordinates_port=EYE_COORDINATES_PORT, cv2_key=None,
- _mware=MWARE, **kwargs):
+ "order": "zyx",
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject", "$_mware", "ICub", "$eye_coordinates_port", should_wait=False
+ )
+ def acquire_eye_coordinates(
+ self,
+ eye_coordinates_port=EYE_COORDINATES_PORT,
+ cv2_key=None,
+ _mware=MWARE,
+ **kwargs,
+ ):
"""
Acquire eye coordinates for controlling the iCub.
@@ -437,7 +507,7 @@ def acquire_eye_coordinates(self, eye_coordinates_port=EYE_COORDINATES_PORT, cv2
if cv2_key is None:
logging.error("controlling orientation in headless mode not yet supported")
- return None,
+ return (None,)
else:
if cv2_key == 27: # Esc key to exit
exit(0)
@@ -462,30 +532,46 @@ def acquire_eye_coordinates(self, eye_coordinates_port=EYE_COORDINATES_PORT, cv2
logging.info("resetting the orientation")
else:
logging.info(cv2_key) # else print its value
- return None,
+ return (None,)
- return {"topic": eye_coordinates_port.split("/")[-1],
+ return (
+ {
+ "topic": eye_coordinates_port.split("/")[-1],
"timestamp": time.time(),
"pitch": self._curr_eyes[0],
"yaw": self._curr_eyes[1],
- "vergence": self._curr_eyes[2]},
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "$gaze_plane_coordinates_port",
- should_wait=False)
- def receive_gaze_plane_coordinates(self, gaze_plane_coordinates_port=GAZE_PLANE_COORDINATES_PORT,
- _mware=MWARE, **kwargs):
+ "vergence": self._curr_eyes[2],
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ICub",
+ "$gaze_plane_coordinates_port",
+ should_wait=False,
+ )
+ def receive_gaze_plane_coordinates(
+ self,
+ gaze_plane_coordinates_port=GAZE_PLANE_COORDINATES_PORT,
+ _mware=MWARE,
+ **kwargs,
+ ):
"""
Receive gaze plane (normalized x,y) coordinates for controlling the iCub.
:param gaze_plane_coordinates_port: str: Port to receive gaze plane coordinates
:return: dict: Gaze plane coordinates
"""
- return None,
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "/icub_controller/logs/wait_for_gaze",
- should_wait=False)
+ return (None,)
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ICub",
+ "/icub_controller/logs/wait_for_gaze",
+ should_wait=False,
+ )
def wait_for_gaze(self, reset=True, _mware=MWARE):
"""
Wait for the gaze actuation to complete.
@@ -506,13 +592,21 @@ def wait_for_gaze(self, reset=True, _mware=MWARE):
self._ipos.positionMove(self._encs.data())
while not self._ipos.checkMotionDone():
pass
- return {"topic": "logging_wait_for_gaze",
+ return (
+ {
+ "topic": "logging_wait_for_gaze",
"timestamp": time.time(),
- "command": f"waiting for gaze completed with reset={reset}"},
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "/icub_controller/logs/reset_gaze",
- should_wait=False)
+ "command": f"waiting for gaze completed with reset={reset}",
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ICub",
+ "/icub_controller/logs/reset_gaze",
+ should_wait=False,
+ )
def reset_gaze(self, _mware=MWARE):
"""
Reset the eyes and head to their original position.
@@ -521,14 +615,24 @@ def reset_gaze(self, _mware=MWARE):
:return: dict: Gaze reset log for a given time step
"""
self.wait_for_gaze(reset=True)
- return {"topic": "logging_reset_gaze",
+ return (
+ {
+ "topic": "logging_reset_gaze",
"timestamp": time.time(),
- "command": f"reset gaze"},
-
- @MiddlewareCommunicator.register("NativeObject", "$mware",
- "ICub", "/icub_controller/logs/head_speed",
- should_wait=False)
- def update_head_gaze_speed(self, pitch=10.0, roll=10.0, yaw=20.0, head=0.8, _mware=MWARE, **kwargs):
+ "command": f"reset gaze",
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$mware",
+ "ICub",
+ "/icub_controller/logs/head_speed",
+ should_wait=False,
+ )
+ def update_head_gaze_speed(
+ self, pitch=10.0, roll=10.0, yaw=20.0, head=0.8, _mware=MWARE, **kwargs
+ ):
"""
Control the iCub head speed.
@@ -542,22 +646,36 @@ def update_head_gaze_speed(self, pitch=10.0, roll=10.0, yaw=20.0, head=0.8, _mwa
if self.ikingaze:
self._igaze.setNeckTrajTime(head)
- return {"topic": "logging_head_speed",
+ return (
+ {
+ "topic": "logging_head_speed",
"timestamp": time.time(),
- "command": f"head speed set to {head}"},
+ "command": f"head speed set to {head}",
+ },
+ )
else:
self._ipos.setRefSpeed(0, pitch)
self._ipos.setRefSpeed(1, roll)
self._ipos.setRefSpeed(2, yaw)
- return {"topic": "logging_head_speed",
+ return (
+ {
+ "topic": "logging_head_speed",
"timestamp": time.time(),
- "command": f"head speed set to {pitch, roll, yaw} (pitch, roll, yaw)"},
-
- @MiddlewareCommunicator.register("NativeObject", "$mware",
- "ICub", "/icub_controller/logs/eye_speed",
- should_wait=False)
- def update_eye_gaze_speed(self, pitch=10.0, yaw=10.0, vergence=20.0, eye=0.5, _mware=MWARE, **kwargs):
+ "command": f"head speed set to {pitch, roll, yaw} (pitch, roll, yaw)",
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$mware",
+ "ICub",
+ "/icub_controller/logs/eye_speed",
+ should_wait=False,
+ )
+ def update_eye_gaze_speed(
+ self, pitch=10.0, yaw=10.0, vergence=20.0, eye=0.5, _mware=MWARE, **kwargs
+ ):
"""
Control the iCub eye speed.
@@ -571,22 +689,36 @@ def update_eye_gaze_speed(self, pitch=10.0, yaw=10.0, vergence=20.0, eye=0.5, _m
if self.ikingaze:
self._igaze.setEyesTrajTime(eye)
- return {"topic": "logging_eye_speed",
+ return (
+ {
+ "topic": "logging_eye_speed",
"timestamp": time.time(),
- "command": f"eye speed set to {eye}"},
+ "command": f"eye speed set to {eye}",
+ },
+ )
else:
self._ipos.setRefSpeed(3, pitch)
self._ipos.setRefSpeed(4, yaw)
self._ipos.setRefSpeed(5, vergence)
- return {"topic": "logging_eye_speed",
+ return (
+ {
+ "topic": "logging_eye_speed",
"timestamp": time.time(),
- "command": f"eye speed set to {pitch, yaw, vergence} (pitch, yaw, vergence)"},
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "/icub_controller/logs/head_orientation_coordinates",
- should_wait=False)
- def control_head_gaze(self, pitch=0.0, roll=0.0, yaw=0.0, order="xyz", _mware=MWARE, **kwargs):
+ "command": f"eye speed set to {pitch, yaw, vergence} (pitch, yaw, vergence)",
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ICub",
+ "/icub_controller/logs/head_orientation_coordinates",
+ should_wait=False,
+ )
+ def control_head_gaze(
+ self, pitch=0.0, roll=0.0, yaw=0.0, order="xyz", _mware=MWARE, **kwargs
+ ):
"""
Control the iCub head relative to previous coordinates following the roll,pitch,yaw convention (order=xyz)
(initialized at 0 looking straight ahead).
@@ -599,8 +731,10 @@ def control_head_gaze(self, pitch=0.0, roll=0.0, yaw=0.0, order="xyz", _mware=MW
:return: dict: Head orientation coordinates log for a given time step
"""
if order != "xyz":
- logging.error("only accepts ratation angles following the order='xyz' convention")
- return None,
+ logging.error(
+ "only accepts ratation angles following the order='xyz' convention"
+ )
+ return (None,)
# wait for the action to complete
# self.wait_for_gaze(reset=False)
@@ -615,14 +749,24 @@ def control_head_gaze(self, pitch=0.0, roll=0.0, yaw=0.0, order="xyz", _mware=MW
# self._ipos.positionMove(self.init_pos_head.data())
self._curr_head = list((pitch, roll, yaw))
- return {"topic": "logging_head_coordinates",
+ return (
+ {
+ "topic": "logging_head_coordinates",
"timestamp": time.time(),
- "command": f"head orientation set to {self._curr_head} (pitch, roll, yaw)"},
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "/icub_controller/logs/eye_orientation_coordinates",
- should_wait=False)
- def control_eye_gaze(self, pitch=0.0, yaw=0.0, vergence=0.0, _mware=MWARE, **kwargs):
+ "command": f"head orientation set to {self._curr_head} (pitch, roll, yaw)",
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ICub",
+ "/icub_controller/logs/eye_orientation_coordinates",
+ should_wait=False,
+ )
+ def control_eye_gaze(
+ self, pitch=0.0, yaw=0.0, vergence=0.0, _mware=MWARE, **kwargs
+ ):
"""
Control the iCub eyes relative to previous coordinates (initialized at 0 looking straight ahead).
@@ -641,19 +785,28 @@ def control_eye_gaze(self, pitch=0.0, yaw=0.0, vergence=0.0, _mware=MWARE, **kwa
# eye control
self.init_pos_eyes.set(3, self.init_pos_eyes.get(3) + pitch) # eye tilt
self.init_pos_eyes.set(4, self.init_pos_eyes.get(4) + yaw) # eye pan/version
- self.init_pos_eyes.set(5, self.init_pos_eyes.get(
- 5) + vergence) # the vergence between the eyes (to align, set to 0)
+ self.init_pos_eyes.set(
+ 5, self.init_pos_eyes.get(5) + vergence
+ ) # the vergence between the eyes (to align, set to 0)
# self._ipos.positionMove(self.init_pos_eyes.data())
self._curr_eyes = list((pitch, yaw, vergence))
- return {"topic": "logging_eye_coordinates",
+ return (
+ {
+ "topic": "logging_eye_coordinates",
"timestamp": time.time(),
- "command": f"eye orientation set to {self._curr_eyes} (pitch, yaw, vergence)"},
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "/icub_controller/logs/head_eye_orientation_coordinates",
- should_wait=False)
+ "command": f"eye orientation set to {self._curr_eyes} (pitch, yaw, vergence)",
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ICub",
+ "/icub_controller/logs/head_eye_orientation_coordinates",
+ should_wait=False,
+ )
def _control_head_eye_gaze(self, _mware=MWARE, **kwargs):
"""
Issue the movement command
@@ -671,18 +824,37 @@ def _control_head_eye_gaze(self, _mware=MWARE, **kwargs):
self.init_pos.set(2, self.init_pos_head.get(2)) # pan/yaw
self.init_pos.set(3, self.init_pos_eyes.get(3)) # eye tilt
self.init_pos.set(4, self.init_pos_eyes.get(4)) # eye pan/version
- self.init_pos.set(5, self.init_pos_eyes.get(5)) # the vergence between the eyes (to align, set to 0)
+ self.init_pos.set(
+ 5, self.init_pos_eyes.get(5)
+ ) # the vergence between the eyes (to align, set to 0)
self._ipos.positionMove(self.init_pos.data())
- return {"topic": "logging_head_eye_coordinates",
+ return (
+ {
+ "topic": "logging_head_eye_coordinates",
"timestamp": time.time(),
- "command": f"head orientation set to {self._curr_head} (pitch, roll, yaw) and eye orientation to {self._curr_eyes} (pitch, yaw, vergence)"},
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "/icub_controller/logs/gaze_plane_coordinates",
- should_wait=False)
- def control_gaze_at_plane(self, x=0.0, y=0.0, limit_x=0.3, limit_y=0.3, control_eyes=True, control_head=True,
- _mware=MWARE, **kwargs):
+ "command": f"head orientation set to {self._curr_head} (pitch, roll, yaw) and eye orientation to {self._curr_eyes} (pitch, yaw, vergence)",
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ICub",
+ "/icub_controller/logs/gaze_plane_coordinates",
+ should_wait=False,
+ )
+ def control_gaze_at_plane(
+ self,
+ x=0.0,
+ y=0.0,
+ limit_x=0.3,
+ limit_y=0.3,
+ control_eyes=True,
+ control_head=True,
+ _mware=MWARE,
+ **kwargs,
+ ):
"""
Gaze at specific point in a normalized plane in front of the iCub.
@@ -704,8 +876,10 @@ def control_gaze_at_plane(self, x=0.0, y=0.0, limit_x=0.3, limit_y=0.3, control_
if control_eyes and control_head:
if not self.ikingaze:
- logging.error("set ikingaze=True in order to move eyes and head simultaneously")
- return None,
+ logging.error(
+ "set ikingaze=True in order to move eyes and head simultaneously"
+ )
+ return (None,)
self.init_pos_ikin = yarp.Vector(3, self._gaze_encs.data())
self.init_pos_ikin.set(0, ptr_degrees[0])
self.init_pos_ikin.set(1, ptr_degrees[1])
@@ -715,23 +889,32 @@ def control_gaze_at_plane(self, x=0.0, y=0.0, limit_x=0.3, limit_y=0.3, control_
elif control_head:
if self.ikingaze:
logging.error("set ikingaze=False in order to move head only")
- return None,
+ return (None,)
self.control_head_gaze(pitch=ptr_degrees[1], roll=0, yaw=ptr_degrees[0])
elif control_eyes:
if self.ikingaze:
logging.error("set ikingaze=False in order to move eyes only")
- return None,
+ return (None,)
self.control_eye_gaze(pitch=ptr_degrees[1], yaw=ptr_degrees[0], vergence=0)
- return {"topic": "logging_gaze_plane_coordinates",
+ return (
+ {
+ "topic": "logging_gaze_plane_coordinates",
"timestamp": time.time(),
- "command": f"moving gaze toward {ptr_degrees} with head={control_head} and eyes={control_eyes}"},
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "$facial_expressions_port",
- should_wait=False)
- def acquire_facial_expressions(self, facial_expressions_port=FACIAL_EXPRESSIONS_PORT, cv2_key=None,
- _mware=MWARE, **kwargs):
+ "command": f"moving gaze toward {ptr_degrees} with head={control_head} and eyes={control_eyes}",
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject", "$_mware", "ICub", "$facial_expressions_port", should_wait=False
+ )
+ def acquire_facial_expressions(
+ self,
+ facial_expressions_port=FACIAL_EXPRESSIONS_PORT,
+ cv2_key=None,
+ _mware=MWARE,
+ **kwargs,
+ ):
"""
Acquire facial expressions from the iCub.
@@ -742,7 +925,7 @@ def acquire_facial_expressions(self, facial_expressions_port=FACIAL_EXPRESSIONS_
emotion = None
if cv2_key is None:
logging.error("controlling expressions in headless mode not yet supported")
- return None,
+ return (None,)
else:
if cv2_key == 27: # Esc key to exit
exit(0)
@@ -780,15 +963,25 @@ def acquire_facial_expressions(self, facial_expressions_port=FACIAL_EXPRESSIONS_
logging.info("expressing shyness")
else:
logging.info(cv2_key) # else print its value
- return None,
- return {"topic": facial_expressions_port.split("/")[-1],
+ return (None,)
+ return (
+ {
+ "topic": facial_expressions_port.split("/")[-1],
"timestamp": time.time(),
- "emotion_category": emotion},
-
- @MiddlewareCommunicator.register("NativeObject", "$_mware",
- "ICub", "/icub_controller/logs/facial_expressions",
- should_wait=False)
- def update_facial_expressions(self, expression, part=False, smoothing="mode", _mware=MWARE, **kwargs):
+ "emotion_category": emotion,
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "$_mware",
+ "ICub",
+ "/icub_controller/logs/facial_expressions",
+ should_wait=False,
+ )
+ def update_facial_expressions(
+ self, expression, part=False, smoothing="mode", _mware=MWARE, **kwargs
+ ):
"""
Control facial expressions of the iCub.
@@ -801,27 +994,34 @@ def update_facial_expressions(self, expression, part=False, smoothing="mode", _m
:return: Emotion log for a given time step
"""
if expression is None:
- return None,
+ return (None,)
if isinstance(expression, (list, tuple)):
expression = expression[-1]
if smoothing == "mode":
self.expressions_queue.append(expression)
- transmitted_expression = mode_smoothing_filter(list(self.expressions_queue), default="neu",
- window_length=self.FACIAL_EXPRESSION_SMOOTHING_WINDOW)
+ transmitted_expression = mode_smoothing_filter(
+ list(self.expressions_queue),
+ default="neu",
+ window_length=self.FACIAL_EXPRESSION_SMOOTHING_WINDOW,
+ )
else:
transmitted_expression = expression
- expressions_lookup = EMOTION_LOOKUP.get(transmitted_expression, transmitted_expression)
+ expressions_lookup = EMOTION_LOOKUP.get(
+ transmitted_expression, transmitted_expression
+ )
if isinstance(expressions_lookup, str):
expressions_lookup = [(part if part else "LIGHTS", expressions_lookup)]
- if self.last_expression[0] == (part if part else "LIGHTS") and self.last_expression[
- 1] == transmitted_expression:
+ if (
+ self.last_expression[0] == (part if part else "LIGHTS")
+ and self.last_expression[1] == transmitted_expression
+ ):
expressions_lookup = []
- for (part_lookup, expression_lookup) in expressions_lookup:
+ for part_lookup, expression_lookup in expressions_lookup:
if part_lookup == "LIGHTS":
self.client.sendline(f"set leb {expression_lookup}")
self.client.expect(">>")
@@ -838,18 +1038,50 @@ def update_facial_expressions(self, expression, part=False, smoothing="mode", _m
self.last_expression[0] = part
self.last_expression[1] = transmitted_expression
- return {"topic": "logging_facial_expressions",
+ return (
+ {
+ "topic": "logging_facial_expressions",
"timestamp": time.time(),
- "command": f"emotion set to {part} {expression} with smoothing={smoothing}"},
-
- @MiddlewareCommunicator.register("Image", "yarp", "ICub", "$cam_world_port",
- width="$img_width", height="$img_height", rgb="$_rgb")
- @MiddlewareCommunicator.register("Image", "yarp", "ICub", "$cam_left_port",
- width="$img_width", height="$img_height", rgb="$_rgb")
- @MiddlewareCommunicator.register("Image", "yarp", "ICub", "$cam_right_port",
- width="$img_width", height="$img_height", rgb="$_rgb")
- def receive_images(self, cam_world_port, cam_left_port, cam_right_port,
- img_width=CAP_PROP_FRAME_WIDTH, img_height=CAP_PROP_FRAME_HEIGHT, _rgb=True):
+ "command": f"emotion set to {part} {expression} with smoothing={smoothing}",
+ },
+ )
+
+ @MiddlewareCommunicator.register(
+ "Image",
+ "yarp",
+ "ICub",
+ "$cam_world_port",
+ width="$img_width",
+ height="$img_height",
+ rgb="$_rgb",
+ )
+ @MiddlewareCommunicator.register(
+ "Image",
+ "yarp",
+ "ICub",
+ "$cam_left_port",
+ width="$img_width",
+ height="$img_height",
+ rgb="$_rgb",
+ )
+ @MiddlewareCommunicator.register(
+ "Image",
+ "yarp",
+ "ICub",
+ "$cam_right_port",
+ width="$img_width",
+ height="$img_height",
+ rgb="$_rgb",
+ )
+ def receive_images(
+ self,
+ cam_world_port,
+ cam_left_port,
+ cam_right_port,
+ img_width=CAP_PROP_FRAME_WIDTH,
+ img_height=CAP_PROP_FRAME_HEIGHT,
+ _rgb=True,
+ ):
"""
Receive images from the iCub.
@@ -883,75 +1115,128 @@ def updateModule(self):
left_cam = cv2.cvtColor(left_cam, cv2.COLOR_BGR2RGB)
right_cam = cv2.cvtColor(right_cam, cv2.COLOR_BGR2RGB)
if not self.headless:
- cv2.imshow("ICubCam", np.concatenate((left_cam, external_cam, right_cam), axis=1))
+ cv2.imshow(
+ "ICubCam", np.concatenate((left_cam, external_cam, right_cam), axis=1)
+ )
k = cv2.waitKey(30)
else:
k = None
- switch_emotion, = self.acquire_facial_expressions(facial_expressions_port=self.FACIAL_EXPRESSIONS_PORT,
- cv2_key=k, _mware=self.MWARE)
+ (switch_emotion,) = self.acquire_facial_expressions(
+ facial_expressions_port=self.FACIAL_EXPRESSIONS_PORT,
+ cv2_key=k,
+ _mware=self.MWARE,
+ )
if switch_emotion is not None and isinstance(switch_emotion, dict):
- self.update_facial_expressions(switch_emotion.get("emotion_category", None),
- part=switch_emotion.get("part", False), _mware=self.MWARE)
+ self.update_facial_expressions(
+ switch_emotion.get("emotion_category", None),
+ part=switch_emotion.get("part", False),
+ _mware=self.MWARE,
+ )
# move robot head
- move_robot_head, = self.acquire_head_coordinates(head_coordinates_port=self.HEAD_COORDINATES_PORT,
- cv2_key=k, _mware=self.MWARE)
+ (move_robot_head,) = self.acquire_head_coordinates(
+ head_coordinates_port=self.HEAD_COORDINATES_PORT,
+ cv2_key=k,
+ _mware=self.MWARE,
+ )
if move_robot_head is not None and isinstance(move_robot_head, dict):
robot_head_speed = move_robot_head.get("speed", False)
if robot_head_speed and isinstance(robot_head_speed, dict):
- self.update_head_gaze_speed(pitch=robot_head_speed.get("pitch", 10.0),
- roll=robot_head_speed.get("roll", 10.0),
- yaw=robot_head_speed.get("yaw", 20.0), _mware=self.MWARE)
+ self.update_head_gaze_speed(
+ pitch=robot_head_speed.get("pitch", 10.0),
+ roll=robot_head_speed.get("roll", 10.0),
+ yaw=robot_head_speed.get("yaw", 20.0),
+ _mware=self.MWARE,
+ )
if move_robot_head.get("reset_gaze", False):
self.reset_gaze()
- self.control_head_gaze(pitch=move_robot_head.get("pitch", 0.0),
- roll=move_robot_head.get("roll", 0.0),
- yaw=move_robot_head.get("yaw", 0.0), _mware=self.MWARE)
+ self.control_head_gaze(
+ pitch=move_robot_head.get("pitch", 0.0),
+ roll=move_robot_head.get("roll", 0.0),
+ yaw=move_robot_head.get("yaw", 0.0),
+ _mware=self.MWARE,
+ )
# move robot eyes
- move_robot_eyes, = self.acquire_eye_coordinates(eye_coordinates_port=self.EYE_COORDINATES_PORT,
- cv2_key=k, _mware=self.MWARE)
+ (move_robot_eyes,) = self.acquire_eye_coordinates(
+ eye_coordinates_port=self.EYE_COORDINATES_PORT, cv2_key=k, _mware=self.MWARE
+ )
if move_robot_eyes is not None and isinstance(move_robot_eyes, dict):
robot_eye_speed = move_robot_eyes.get("speed", False)
if robot_eye_speed and isinstance(robot_eye_speed, dict):
- self.update_eye_gaze_speed(pitch=robot_eye_speed.get("pitch", 10.0),
- yaw=robot_eye_speed.get("yaw", 10.0),
- vergence=robot_eye_speed.get("vergence", 20.0), _mware=self.MWARE)
+ self.update_eye_gaze_speed(
+ pitch=robot_eye_speed.get("pitch", 10.0),
+ yaw=robot_eye_speed.get("yaw", 10.0),
+ vergence=robot_eye_speed.get("vergence", 20.0),
+ _mware=self.MWARE,
+ )
if move_robot_eyes.get("reset_gaze", False):
self.reset_gaze()
- self.control_eye_gaze(pitch=move_robot_eyes.get("pitch", 0.0),
- yaw=move_robot_eyes.get("yaw", 0.0),
- vergence=move_robot_eyes.get("vergence", 0.0), _mware=self.MWARE)
+ self.control_eye_gaze(
+ pitch=move_robot_eyes.get("pitch", 0.0),
+ yaw=move_robot_eyes.get("yaw", 0.0),
+ vergence=move_robot_eyes.get("vergence", 0.0),
+ _mware=self.MWARE,
+ )
if move_robot_head is not None or move_robot_eyes is not None:
self._control_head_eye_gaze()
- move_robot, = self.receive_gaze_plane_coordinates(gaze_plane_coordinates_port=self.GAZE_PLANE_COORDINATES_PORT,
- _mware=self.MWARE)
+ (move_robot,) = self.receive_gaze_plane_coordinates(
+ gaze_plane_coordinates_port=self.GAZE_PLANE_COORDINATES_PORT,
+ _mware=self.MWARE,
+ )
if move_robot is not None and isinstance(move_robot, dict):
robot_eye_speed = move_robot.get("eye_speed", False)
if robot_eye_speed and isinstance(robot_eye_speed, dict):
- self.update_eye_gaze_speed(**{"pitch": robot_eye_speed.get("pitch", 10.0),
- "yaw": robot_eye_speed.get("yaw", 10.0),
- "vergence": robot_eye_speed.get("vergence", 20.0), "_mware": self.MWARE}
- if not self.ikingaze else {"eye": robot_eye_speed.get("eye", 0.5), "_mware": self.MWARE})
+ self.update_eye_gaze_speed(
+ **(
+ {
+ "pitch": robot_eye_speed.get("pitch", 10.0),
+ "yaw": robot_eye_speed.get("yaw", 10.0),
+ "vergence": robot_eye_speed.get("vergence", 20.0),
+ "_mware": self.MWARE,
+ }
+ if not self.ikingaze
+ else {
+ "eye": robot_eye_speed.get("eye", 0.5),
+ "_mware": self.MWARE,
+ }
+ )
+ )
robot_head_speed = move_robot.get("head_speed", False)
if robot_head_speed and isinstance(robot_head_speed, dict):
- self.update_head_gaze_speed(**{"pitch": robot_head_speed.get("pitch", 10.0),
- "roll": robot_head_speed.get("roll", 10.0),
- "yaw": robot_head_speed.get("yaw", 20.0), "_mware": self.MWARE}
- if not self.ikingaze else {"head": robot_head_speed.get("head", 0.8), "_mware": self.MWARE})
+ self.update_head_gaze_speed(
+ **(
+ {
+ "pitch": robot_head_speed.get("pitch", 10.0),
+ "roll": robot_head_speed.get("roll", 10.0),
+ "yaw": robot_head_speed.get("yaw", 20.0),
+ "_mware": self.MWARE,
+ }
+ if not self.ikingaze
+ else {
+ "head": robot_head_speed.get("head", 0.8),
+ "_mware": self.MWARE,
+ }
+ )
+ )
if move_robot.get("reset_gaze", False):
self.reset_gaze()
- self.control_gaze_at_plane(x=move_robot.get("x", 0.0), y=move_robot.get("y", 0.0),
- limit_x=move_robot.get("limit_x", 0.3),
- limit_y=move_robot.get("limit_y", 0.3),
- control_head=move_robot.get("control_head",
- False if not self.ikingaze else True),
- control_eyes=move_robot.get("control_eyes", True), _mware=self.MWARE),
+ self.control_gaze_at_plane(
+ x=move_robot.get("x", 0.0),
+ y=move_robot.get("y", 0.0),
+ limit_x=move_robot.get("limit_x", 0.3),
+ limit_y=move_robot.get("limit_y", 0.3),
+ control_head=move_robot.get(
+ "control_head", False if not self.ikingaze else True
+ ),
+ control_eyes=move_robot.get("control_eyes", True),
+ _mware=self.MWARE,
+ ),
return True
@@ -961,39 +1246,77 @@ def parse_args():
parser.add_argument("--simulation", action="store_true", help="Run in simulation")
parser.add_argument("--headless", action="store_true", help="Disable CV2 GUI")
parser.add_argument("--ikingaze", action="store_true", help="Enable iKinGazeCtrl")
- parser.add_argument("--get_cam_feed", action="store_true", help="Get the camera feeds from the robot")
+ parser.add_argument(
+ "--get_cam_feed",
+ action="store_true",
+ help="Get the camera feeds from the robot",
+ )
parser.add_argument("--control_head", action="store_true", help="Control the head")
- parser.add_argument("--set_head_coordinates", action="store_true",
- help="Publish head coordinates set using keyboard commands")
- parser.add_argument("--head_coordinates_port", type=str, default="",
- help="The port (topic) name used for receiving and transmitting head orientation "
- "Setting the port name without --set_head_coordinates will only receive the coordinates")
+ parser.add_argument(
+ "--set_head_coordinates",
+ action="store_true",
+ help="Publish head coordinates set using keyboard commands",
+ )
+ parser.add_argument(
+ "--head_coordinates_port",
+ type=str,
+ default="",
+ help="The port (topic) name used for receiving and transmitting head orientation "
+ "Setting the port name without --set_head_coordinates will only receive the coordinates",
+ )
parser.add_argument("--control_eyes", action="store_true", help="Control the eyes")
- parser.add_argument("--set_eye_coordinates", action="store_true",
- help="Publish eye coordinates set using keyboard commands")
- parser.add_argument("--eye_coordinates_port", type=str, default="",
- help="The port (topic) name used for receiving and transmitting eye orientation "
- "Setting the port name without --set_eye_coordinates will only receive the coordinates")
- parser.add_argument("--gaze_plane_coordinates_port", type=str, default="",
- help="The port (topic) name used for receiving plane coordinates in 2D for robot to look at")
- parser.add_argument("--control_expressions", action="store_true", help="Control the facial expressions")
- parser.add_argument("--set_facial_expressions", action="store_true",
- help="Publish facial expressions set using keyboard commands")
- parser.add_argument("--facial_expressions_port", type=str, default="",
- help="The port (topic) name used for receiving and transmitting facial expressions. "
- "Setting the port name without --set_facial_expressions will only receive the facial expressions")
- parser.add_argument("--mware", type=str, default=ICUB_DEFAULT_COMMUNICATOR,
- help="The middleware used for communication. "
- "This can be overriden by providing either of the following environment variables "
- "{WRAPYFI_DEFAULT_COMMUNICATOR, WRAPYFI_DEFAULT_MWARE, "
- "ICUB_DEFAULT_COMMUNICATOR, ICUB_DEFAULT_MWARE}. Defaults to 'yarp'",
- choices=MiddlewareCommunicator.get_communicators())
+ parser.add_argument(
+ "--set_eye_coordinates",
+ action="store_true",
+ help="Publish eye coordinates set using keyboard commands",
+ )
+ parser.add_argument(
+ "--eye_coordinates_port",
+ type=str,
+ default="",
+ help="The port (topic) name used for receiving and transmitting eye orientation "
+ "Setting the port name without --set_eye_coordinates will only receive the coordinates",
+ )
+ parser.add_argument(
+ "--gaze_plane_coordinates_port",
+ type=str,
+ default="",
+ help="The port (topic) name used for receiving plane coordinates in 2D for robot to look at",
+ )
+ parser.add_argument(
+ "--control_expressions",
+ action="store_true",
+ help="Control the facial expressions",
+ )
+ parser.add_argument(
+ "--set_facial_expressions",
+ action="store_true",
+ help="Publish facial expressions set using keyboard commands",
+ )
+ parser.add_argument(
+ "--facial_expressions_port",
+ type=str,
+ default="",
+ help="The port (topic) name used for receiving and transmitting facial expressions. "
+ "Setting the port name without --set_facial_expressions will only receive the facial expressions",
+ )
+ parser.add_argument(
+ "--mware",
+ type=str,
+ default=ICUB_DEFAULT_COMMUNICATOR,
+ help="The middleware used for communication. "
+ "This can be overriden by providing either of the following environment variables "
+ "{WRAPYFI_DEFAULT_COMMUNICATOR, WRAPYFI_DEFAULT_MWARE, "
+ "ICUB_DEFAULT_COMMUNICATOR, ICUB_DEFAULT_MWARE}. Defaults to 'yarp'",
+ choices=MiddlewareCommunicator.get_communicators(),
+ )
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
- assert not (args.headless and (args.set_facial_expressions or args.set_head_eye_coordinates)), \
- "setters require a CV2 window for capturing keystrokes. Disable --set-... for running in headless mode"
+ assert not (
+ args.headless and (args.set_facial_expressions or args.set_head_eye_coordinates)
+ ), "setters require a CV2 window for capturing keystrokes. Disable --set-... for running in headless mode"
controller = ICub(**vars(args))
- controller.runModule()
\ No newline at end of file
+ controller.runModule()
diff --git a/examples/sensors/cam_mic.py b/examples/sensors/cam_mic.py
index a56687e..34ad2ff 100755
--- a/examples/sensors/cam_mic.py
+++ b/examples/sensors/cam_mic.py
@@ -65,9 +65,20 @@
class CamMic(MiddlewareCommunicator):
__registry__ = {}
- def __init__(self, *args, stream=("audio", "video"), mic_source=0,
- mic_rate=44100, mic_chunk=10000, mic_channels=1, img_source=0,
- img_width=320, img_height=240, mware=None, **kwargs):
+ def __init__(
+ self,
+ *args,
+ stream=("audio", "video"),
+ mic_source=0,
+ mic_rate=44100,
+ mic_chunk=10000,
+ mic_channels=1,
+ img_source=0,
+ img_width=320,
+ img_height=240,
+ mware=None,
+ **kwargs,
+ ):
super(MiddlewareCommunicator, self).__init__()
self.mic_source = mic_source
self.mic_rate = mic_rate
@@ -83,11 +94,22 @@ def __init__(self, *args, stream=("audio", "video"), mic_source=0,
self.mware = mware
- @MiddlewareCommunicator.register("Image", "$mware", "CamMic", "/cam_mic/cam_feed",
- carrier="", width="$img_width", height="$img_height", rgb=True, jpg=True,
- queue_size=10)
+ @MiddlewareCommunicator.register(
+ "Image",
+ "$mware",
+ "CamMic",
+ "/cam_mic/cam_feed",
+ carrier="",
+ width="$img_width",
+ height="$img_height",
+ rgb=True,
+ jpg=True,
+ queue_size=10,
+ )
def collect_cam(self, img_width=320, img_height=240, mware=None):
- """Collect images from the camera."""
+ """
+ Collect images from the camera.
+ """
if self.vid_cap is True:
self.vid_cap = cv2.VideoCapture(self.img_source)
if img_width > 0 and img_height > 0:
@@ -102,25 +124,53 @@ def collect_cam(self, img_width=320, img_height=240, mware=None):
print("Video frame grabbed")
else:
print("Video frame not grabbed")
- img = np.random.randint(256, size=(img_height, img_width, 3), dtype=np.uint8)
+ img = np.random.randint(
+ 256, size=(img_height, img_width, 3), dtype=np.uint8
+ )
else:
print("Video capturer not opened")
- img = np.random.randint(256, size=(img_height, img_width, 3), dtype=np.uint8)
- return img,
+ img = np.random.randint(
+ 256, size=(img_height, img_width, 3), dtype=np.uint8
+ )
+ return (img,)
- @MiddlewareCommunicator.register("AudioChunk", "$mware", "CamMic", "/cam_mic/audio_feed",
- carrier="", rate="$mic_rate", chunk="$mic_chunk", channels="$mic_channels")
- def collect_mic(self, aud=None, mic_rate=44100, mic_chunk=int(44100 / 5), mic_channels=1, mware=None):
- """Collect audio from the microphone."""
+ @MiddlewareCommunicator.register(
+ "AudioChunk",
+ "$mware",
+ "CamMic",
+ "/cam_mic/audio_feed",
+ carrier="",
+ rate="$mic_rate",
+ chunk="$mic_chunk",
+ channels="$mic_channels",
+ )
+ def collect_mic(
+ self,
+ aud=None,
+ mic_rate=44100,
+ mic_chunk=int(44100 / 5),
+ mic_channels=1,
+ mware=None,
+ ):
+ """
+ Collect audio from the microphone.
+ """
aud = aud, mic_rate
- return aud,
+ return (aud,)
def capture_cam_mic(self):
- """Capture audio and video from the camera and microphone."""
+ """
+ Capture audio and video from the camera and microphone.
+ """
if self.enable_audio:
# capture the audio stream from the microphone
- with sd.InputStream(device=self.mic_source, channels=self.mic_channels, callback=self._mic_callback,
- blocksize=self.mic_chunk, samplerate=self.mic_rate):
+ with sd.InputStream(
+ device=self.mic_source,
+ channels=self.mic_channels,
+ callback=self._mic_callback,
+ blocksize=self.mic_chunk,
+ samplerate=self.mic_rate,
+ ):
while True:
pass
elif self.enable_video:
@@ -128,36 +178,92 @@ def capture_cam_mic(self):
self.collect_cam(mware=self.mware)
def _mic_callback(self, audio, frames, time, status):
- """Callback for the microphone audio stream."""
+ """
+ Callback for the microphone audio stream.
+ """
if self.enable_video:
- self.collect_cam(img_width=self.img_width, img_height=self.img_height, mware=self.mware)
- self.collect_mic(audio, mic_rate=self.mic_rate, mic_chunk=self.mic_chunk, mic_channels=self.mic_channels, mware=self.mware)
+ self.collect_cam(
+ img_width=self.img_width, img_height=self.img_height, mware=self.mware
+ )
+ self.collect_mic(
+ audio,
+ mic_rate=self.mic_rate,
+ mic_chunk=self.mic_chunk,
+ mic_channels=self.mic_channels,
+ mware=self.mware,
+ )
print(audio.flatten(), audio.min(), audio.mean(), audio.max())
def __del__(self):
- """Release the video capture device."""
+ """
+ Release the video capture device.
+ """
if not isinstance(self.vid_cap, bool):
self.vid_cap.release()
def parse_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="A streamer and listener for audio and video streams.")
- parser.add_argument("--mode", type=str, default="publish", choices={"publish", "listen"}, help="The transmission mode")
- parser.add_argument("--mware", type=str, default=DEFAULT_COMMUNICATOR, choices=MiddlewareCommunicator.get_communicators(),
- help="The middleware to use for transmission")
- parser.add_argument("--stream", nargs="+", default=["video", "audio"],
- choices={"video", "audio"},
- help="The streamed sensor data")
- parser.add_argument("--img_source", type=int, default=0, help="The video capture device id (int camera id)")
+ """
+ Parse command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="A streamer and listener for audio and video streams."
+ )
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default="publish",
+ choices={"publish", "listen"},
+ help="The transmission mode",
+ )
+ parser.add_argument(
+ "--mware",
+ type=str,
+ default=DEFAULT_COMMUNICATOR,
+ choices=MiddlewareCommunicator.get_communicators(),
+ help="The middleware to use for transmission",
+ )
+ parser.add_argument(
+ "--stream",
+ nargs="+",
+ default=["video", "audio"],
+ choices={"video", "audio"},
+ help="The streamed sensor data",
+ )
+ parser.add_argument(
+ "--img_source",
+ type=int,
+ default=0,
+ help="The video capture device id (int camera id)",
+ )
parser.add_argument("--img_width", type=int, default=320, help="The image width")
parser.add_argument("--img_height", type=int, default=240, help="The image height")
- parser.add_argument("--mic_source", type=int, default=None, help="The audio capture device id (int microphone id from python3 -m sounddevice)")
- parser.add_argument("--mic_rate", type=int, default=44100, help="The audio sampling rate")
- parser.add_argument("--mic_channels", type=int, default=1, help="The audio channels")
- parser.add_argument("--mic_chunk", type=int, default=10000, help="The transmitted audio chunk size")
- parser.add_argument("--sound_device", type=int, default=0, help="The sound device to use for audio playback")
- parser.add_argument("--list_sound_devices", action="store_true", help="List all available sound devices and exit")
+ parser.add_argument(
+ "--mic_source",
+ type=int,
+ default=None,
+ help="The audio capture device id (int microphone id from python3 -m sounddevice)",
+ )
+ parser.add_argument(
+ "--mic_rate", type=int, default=44100, help="The audio sampling rate"
+ )
+ parser.add_argument(
+ "--mic_channels", type=int, default=1, help="The audio channels"
+ )
+ parser.add_argument(
+ "--mic_chunk", type=int, default=10000, help="The transmitted audio chunk size"
+ )
+ parser.add_argument(
+ "--sound_device",
+ type=int,
+ default=0,
+ help="The sound device to use for audio playback",
+ )
+ parser.add_argument(
+ "--list_sound_devices",
+ action="store_true",
+ help="List all available sound devices and exit",
+ )
return parser.parse_args()
@@ -174,7 +280,9 @@ def sound_play(my_aud, blocking=True, device=0):
sd.play(*my_aud, blocking=blocking, device=device)
return True
except sd.PortAudioError:
- logging.warning("PortAudioError: No device is found or the device is already in use. Will try again in 3 seconds.")
+ logging.warning(
+ "PortAudioError: No device is found or the device is already in use. Will try again in 3 seconds."
+ )
return False
@@ -194,11 +302,21 @@ def main(args):
cam_mic.activate_communication(CamMic.collect_mic, mode="listen")
while True:
if "audio" in args.stream:
- (aud, mic_rate), = cam_mic.collect_mic(mic_rate=args.mic_rate, mic_chunk=args.mic_chunk, mic_channels=args.mic_channels, mware=args.mware)
+ ((aud, mic_rate),) = cam_mic.collect_mic(
+ mic_rate=args.mic_rate,
+ mic_chunk=args.mic_chunk,
+ mic_channels=args.mic_channels,
+ mware=args.mware,
+ )
else:
aud = mic_rate = None
if "video" in args.stream:
- img, = cam_mic.collect_cam(img_source=args.img_source, img_width=args.img_width, img_height=args.img_height, mware=args.mware)
+ (img,) = cam_mic.collect_cam(
+ img_source=args.img_source,
+ img_width=args.img_width,
+ img_height=args.img_height,
+ mware=args.mware,
+ )
else:
img = None
if img is not None:
diff --git a/setup.py b/setup.py
index 9ca68ba..ae26782 100755
--- a/setup.py
+++ b/setup.py
@@ -8,21 +8,31 @@ def check_cv2(default_python="opencv-python"):
import pkg_resources
from packaging import version
import cv2
+
if version.parse(cv2.__version__) < version.parse(REQUIRED_CV2_VERSION):
UPGRADE_CV2 = True
raise ImportError(f"OpenCV version must be at least {REQUIRED_CV2_VERSION}")
except ImportError as e:
import pkg_resources
+
if UPGRADE_CV2:
print(e, "Will try to upgrade OpenCV")
if "opencv-python" in [p.project_name for p in pkg_resources.working_set]:
additional_packages = [f"opencv-python>={REQUIRED_CV2_VERSION}"]
- elif "opencv-contrib-python" in [p.project_name for p in pkg_resources.working_set]:
+ elif "opencv-contrib-python" in [
+ p.project_name for p in pkg_resources.working_set
+ ]:
additional_packages = [f"opencv-contrib-python>={REQUIRED_CV2_VERSION}"]
- elif "opencv-python-headless" in [p.project_name for p in pkg_resources.working_set]:
- additional_packages = [f"opencv-python-headless>={REQUIRED_CV2_VERSION}"]
+ elif "opencv-python-headless" in [
+ p.project_name for p in pkg_resources.working_set
+ ]:
+ additional_packages = [
+ f"opencv-python-headless>={REQUIRED_CV2_VERSION}"
+ ]
else:
- raise ImportError(f"Unknown OpenCV package installed. Please upgrade manually to version >={REQUIRED_CV2_VERSION}")
+ raise ImportError(
+ f"Unknown OpenCV package installed. Please upgrade manually to version >={REQUIRED_CV2_VERSION}"
+ )
else:
print(f"OpenCV not found. Will try to install {default_python}")
additional_packages = [f"{default_python}>={REQUIRED_CV2_VERSION}"]
@@ -33,26 +43,53 @@ def check_cv2(default_python="opencv-python"):
setuptools.setup(
- name = 'wrapyfi',
- version = '0.4.39',
- description = 'Wrapyfi is a wrapper for simplifying Middleware communication',
- url = 'https://github.com/fabawi/wrapyfi/blob/main/',
+ name="wrapyfi",
+ version="0.4.40",
+ description="Wrapyfi is a wrapper for simplifying Middleware communication",
+ url="https://github.com/fabawi/wrapyfi/blob/main/",
project_urls={
- 'Documentation': 'https://wrapyfi.readthedocs.io/en/latest/',
- 'Source': 'https://github.com/fabawi/wrapyfi/',
- 'Tracker': 'https://github.com/fabawi/wrapyfi/issues',
+ "Documentation": "https://wrapyfi.readthedocs.io/en/latest/",
+ "Source": "https://github.com/fabawi/wrapyfi/",
+ "Tracker": "https://github.com/fabawi/wrapyfi/issues",
+ },
+ author="Fares Abawi",
+ author_email="f.abawi@outlook.com",
+ maintainer="Fares Abawi",
+ maintainer_email="f.abawi@outlook.com",
+ packages=setuptools.find_packages(),
+ extras_require={
+ "docs": ["sphinx", "sphinx_rtd_theme", "myst_parser"],
+ "pyzmq": ["pyzmq>=19.0.0"],
+ "numpy": ["numpy>=1.19.2"],
+ "headless": ["wrapyfi[pyzmq]", "wrapyfi[numpy]"]
+ + check_cv2("opencv-python-headless"),
+ "all": ["wrapyfi[pyzmq]", "wrapyfi[numpy]"]
+ + check_cv2("opencv-contrib-python"),
},
- author = 'Fares Abawi',
- author_email = 'f.abawi@outlook.com',
- maintainer = 'Fares Abawi',
- maintainer_email = 'f.abawi@outlook.com',
- packages = setuptools.find_packages(),
- extras_require ={'docs': ['sphinx', 'sphinx_rtd_theme', 'myst_parser'],
- 'pyzmq': ['pyzmq>=19.0.0'],
- 'numpy': ['numpy>=1.19.2'],
- 'headless': ['wrapyfi[pyzmq]', 'wrapyfi[numpy]'] + check_cv2("opencv-python-headless"),
- 'all': ['wrapyfi[pyzmq]', 'wrapyfi[numpy]'] + check_cv2("opencv-contrib-python")},
- install_requires = ['pyyaml>=5.1.1'],
- python_requires = '>=3.6',
- setup_requires = ['cython>=0.29.1']
+ install_requires=["pyyaml>=5.1.1"],
+ python_requires=">=3.6",
+ setup_requires=["cython>=0.29.1"],
+ classifiers=[
+ "Development Status :: 4 - Beta",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Science/Research",
+ "Intended Audience :: Robotics",
+ "Topic :: Middleware Wrapper",
+ "Topic :: Middlware :: ZeroMQ",
+ "Topic :: Middlware :: YARP",
+ "Topic :: Middlware :: ROS",
+ "Topic :: Middlware :: ROS 2",
+ "Topic :: Scientific/Engineering :: Deep Learning",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Scientific/Engineering :: Robotics",
+ "Topic :: Scientific/Engineering :: Image Processing",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: Linux",
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ ],
)
diff --git a/upgrade_project.sh b/upgrade_project.sh
index 181bb77..e0d679c 100755
--- a/upgrade_project.sh
+++ b/upgrade_project.sh
@@ -39,6 +39,10 @@ update_version_in_package_xml() {
# GENERATE DOCUMENTATION
#######################################################################################################################
+# refactor code with black
+python3 -m pip install black
+black .
+
# compile docs with sphinx
cd docs
python3 -m pip install -r requirements.txt
diff --git a/wrapyfi/__init__.py b/wrapyfi/__init__.py
index 003cc19..d04edee 100755
--- a/wrapyfi/__init__.py
+++ b/wrapyfi/__init__.py
@@ -7,7 +7,7 @@ def get_project_info_from_setup():
try:
curr_dir = os.path.dirname(__file__)
setup_path = os.path.join(curr_dir, "..", "setup.py")
- with open(setup_path, 'r') as f:
+ with open(setup_path, "r") as f:
content = f.read()
except FileNotFoundError:
return {}
@@ -15,54 +15,72 @@ def get_project_info_from_setup():
version_match = re.search(r"version\s*=\s*['\"]([^'\"]*)['\"]", content)
url_match = re.search(r"url\s*=\s*['\"]([^'\"]*)['\"]", content)
doc_match = re.search(r"'Documentation':\s*['\"]([^'\"]*)['\"]", content)
-
+
if not name_match or not version_match or not url_match:
# raise RuntimeError("Unable to find name, version, or url string.")
return {}
-
+
return {
- 'name': name_match.group(1),
- 'version': version_match.group(1),
- 'url': url_match.group(1),
- 'doc': None if not doc_match else doc_match.group(1)
+ "name": name_match.group(1),
+ "version": version_match.group(1),
+ "url": url_match.group(1),
+ "doc": None if not doc_match else doc_match.group(1),
}
# extract project info
project_info = get_project_info_from_setup()
-__version__ = project_info.get('version', None)
-__url__ = project_info.get('url', None)
-__doc__ = project_info.get('doc', None)
-name = project_info.get('name', 'wrapyfi')
+__version__ = project_info.get("version", None)
+__url__ = project_info.get("url", None)
+__doc__ = project_info.get("doc", None)
+name = project_info.get("name", "wrapyfi")
if __version__ is None or __url__ is None or __doc__ is None:
try:
from importlib import metadata
+
mdata = metadata.metadata(__name__)
__version__ = metadata.version(__name__)
__url__ = mdata["Home-page"]
# when installed with PyPi
if __url__ is None:
for url_extract in mdata.get_all("Project-URL"):
- __url__ = url_extract.split(", ")[1] if url_extract.split(", ")[0] == "Homepage" else __url__
+ __url__ = (
+ url_extract.split(", ")[1]
+ if url_extract.split(", ")[0] == "Homepage"
+ else __url__
+ )
if __doc__ is None:
for url_extract in mdata.get_all("Project-URL"):
- __doc__ = url_extract.split(", ")[1] if url_extract.split(", ")[0] == "Documentation" else __doc__
+ __doc__ = (
+ url_extract.split(", ")[1]
+ if url_extract.split(", ")[0] == "Documentation"
+ else __doc__
+ )
except ImportError:
try:
# when Python < 3.8 and setuptools/pip have not been updated
import pkg_resources
+
mdata = pkg_resources.get_distribution(__name__).metadata
__version__ = pkg_resources.require(__name__)[0].version
__url__ = mdata["Home-page"]
# when installed with PyPi
if __url__ is None:
for url_extract in mdata.get_all("Project-URL"):
- __url__ = url_extract.split(", ")[1] if url_extract.split(", ")[0] == "Homepage" else __url__
+ __url__ = (
+ url_extract.split(", ")[1]
+ if url_extract.split(", ")[0] == "Homepage"
+ else __url__
+ )
if __doc__ is None:
for url_extract in mdata.get_all("Project-URL"):
- __doc__ = url_extract.split(", ")[1] if url_extract.split(", ")[0] == "Documentation" else __doc__
+ __doc__ = (
+ url_extract.split(", ")[1]
+ if url_extract.split(", ")[0] == "Documentation"
+ else __doc__
+ )
except pkg_resources.DistributionNotFound:
__version__ = "unknown_version"
__url__ = "unknown_url"
@@ -77,4 +95,5 @@ def get_project_info_from_setup():
PluginRegistrar.scan()
import logging
+
logging.getLogger().setLevel(logging.INFO)
diff --git a/wrapyfi/clients/__init__.py b/wrapyfi/clients/__init__.py
index b4c87df..13a768e 100755
--- a/wrapyfi/clients/__init__.py
+++ b/wrapyfi/clients/__init__.py
@@ -6,9 +6,18 @@
@Clients.register("MMO", "fallback")
class FallbackClient(Client):
- def __init__(self, name: str, in_topic: str, carrier: str = "", missing_middleware_object: str = "", **kwargs):
- logging.warning(f"Fallback client employed due to missing middleware or object type: "
- f"{missing_middleware_object}")
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "",
+ missing_middleware_object: str = "",
+ **kwargs,
+ ):
+ logging.warning(
+ f"Fallback client employed due to missing middleware or object type: "
+ f"{missing_middleware_object}"
+ )
Client.__init__(self, name, in_topic, carrier=carrier, **kwargs)
self.missing_middleware_object = missing_middleware_object
@@ -25,4 +34,4 @@ def _await_reply(self):
return None
def close(self):
- return None
\ No newline at end of file
+ return None
diff --git a/wrapyfi/clients/ros.py b/wrapyfi/clients/ros.py
index ec02b12..3117597 100755
--- a/wrapyfi/clients/ros.py
+++ b/wrapyfi/clients/ros.py
@@ -9,13 +9,23 @@
import std_msgs.msg
from wrapyfi.connect.clients import Client, Clients
-from wrapyfi.middlewares.ros import ROSMiddleware, ROSNativeObjectService, ROSImageService
+from wrapyfi.middlewares.ros import (
+ ROSMiddleware,
+ ROSNativeObjectService,
+ ROSImageService,
+)
from wrapyfi.encoders import JsonEncoder, JsonDecodeHook
class ROSClient(Client):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp",
- ros_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ ros_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the client.
@@ -26,7 +36,9 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp",
:param kwargs: dict: Additional kwargs for the client
"""
if carrier or carrier != "tcp":
- logging.warning("[ROS] ROS does not support other carriers than TCP for REQ/REP pattern. Using TCP.")
+ logging.warning(
+ "[ROS] ROS does not support other carriers than TCP for REQ/REP pattern. Using TCP."
+ )
carrier = "tcp"
super().__init__(name, in_topic, carrier=carrier, **kwargs)
ROSMiddleware.activate(**ros_kwargs or {})
@@ -45,9 +57,16 @@ def __del__(self):
@Clients.register("NativeObject", "ros")
class ROSNativeObjectClient(ROSClient):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", persistent: bool = True,
- serializer_kwargs: Optional[dict] = None,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ persistent: bool = True,
+ serializer_kwargs: Optional[dict] = None,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObject client using the ROS String message assuming the data is serialized as a JSON string.
Deserializes the data (including plugins) using the decoder and parses it to a Python object.
@@ -76,7 +95,9 @@ def establish(self):
Establish the client's connection to the ROS service.
"""
rospy.wait_for_service(self.in_topic)
- self._client = rospy.ServiceProxy(self.in_topic, ROSNativeObjectService, persistent=self.persistent)
+ self._client = rospy.ServiceProxy(
+ self.in_topic, ROSNativeObjectService, persistent=self.persistent
+ )
if self.persistent:
self.established = True
@@ -103,13 +124,19 @@ def _request(self, *args, **kwargs):
:param args: tuple: Positional arguments to send in the request.
:param kwargs: dict: Keyword arguments to send in the request.
"""
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs],
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
args_msg = std_msgs.msg.String()
args_msg.data = args_str
msg = self._client(args_msg)
- obj = json.loads(msg.data, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ obj = json.loads(
+ msg.data, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs
+ )
self._queue.put(obj, block=False)
def _await_reply(self) -> Any:
@@ -122,16 +149,29 @@ def _await_reply(self) -> Any:
reply = self._queue.get(block=True)
return reply
except queue.Full:
- logging.warning(f"[ROS] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__class__.__name__}")
+ logging.warning(
+ f"[ROS] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__class__.__name__}"
+ )
return None
@Clients.register("Image", "ros")
class ROSImageClient(ROSClient):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", width: int = -1, height: int = -1, persistent: bool = True,
- rgb: bool = True, fp: bool = False, serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ width: int = -1,
+ height: int = -1,
+ persistent: bool = True,
+ rgb: bool = True,
+ fp: bool = False,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The Image client using the ROS Image message parsed to a numpy array.
@@ -146,7 +186,9 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", width: int =
:param serializer_kwargs: dict: Additional kwargs for the serializer
"""
if "jpg" in kwargs:
- logging.warning("[ROS] ROS currently does not support JPG encoding in REQ/REP. Using raw image.")
+ logging.warning(
+ "[ROS] ROS currently does not support JPG encoding in REQ/REP. Using raw image."
+ )
kwargs.pop("jpg")
super().__init__(name, in_topic, carrier=carrier, **kwargs)
self.width = width
@@ -155,10 +197,10 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", width: int =
self.fp = fp
if self.fp:
- self._encoding = '32FC3' if self.rgb else '32FC1'
+ self._encoding = "32FC3" if self.rgb else "32FC1"
self._type = np.float32
else:
- self._encoding = 'bgr8' if self.rgb else 'mono8'
+ self._encoding = "bgr8" if self.rgb else "mono8"
self._type = np.uint8
self._pixel_bytes = (3 if self.rgb else 1) * np.dtype(self._type).itemsize
@@ -176,7 +218,9 @@ def establish(self):
Establish the client's connection to the ROS service.
"""
rospy.wait_for_service(self.in_topic)
- self._client = rospy.ServiceProxy(self.in_topic, ROSImageService, persistent=self.persistent)
+ self._client = rospy.ServiceProxy(
+ self.in_topic, ROSImageService, persistent=self.persistent
+ )
if self.persistent:
self.established = True
@@ -203,12 +247,19 @@ def _request(self, *args, **kwargs):
:param args: tuple: Positional arguments to send in the request
:param kwargs: dict: Keyword arguments to send in the request
"""
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs],
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
args_msg = std_msgs.msg.String()
args_msg.data = args_str
msg = self._client(args_msg)
- self._queue.put((msg.height, msg.width, msg.encoding, msg.is_bigendian, msg.data), block=False)
+ self._queue.put(
+ (msg.height, msg.width, msg.encoding, msg.is_bigendian, msg.data),
+ block=False,
+ )
def _await_reply(self):
"""
@@ -220,23 +271,42 @@ def _await_reply(self):
height, width, encoding, is_bigendian, data = self._queue.get(block=True)
if encoding != self._encoding:
raise ValueError("Incorrect encoding for listener")
- if 0 < self.width != width or 0 < self.height != height or len(data) != height * width * self._pixel_bytes:
+ if (
+ 0 < self.width != width
+ or 0 < self.height != height
+ or len(data) != height * width * self._pixel_bytes
+ ):
raise ValueError("Incorrect image shape for listener")
- img = np.frombuffer(data, dtype=np.dtype(self._type).newbyteorder('>' if is_bigendian else '<')).reshape((height, width, -1))
+ img = np.frombuffer(
+ data,
+ dtype=np.dtype(self._type).newbyteorder(">" if is_bigendian else "<"),
+ ).reshape((height, width, -1))
if img.shape[2] == 1:
img = img.squeeze(axis=2)
return img
except queue.Full:
- logging.warning(f"[ROS] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
return None
@Clients.register("AudioChunk", "ros")
class ROSAudioChunkClient(ROSClient):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", persistent: bool = True,
- channels: int = 1, rate: int = 44100, chunk: int = -1, serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ persistent: bool = True,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The AudioChunk client using the ROS Audio message parsed to a numpy array.
@@ -271,12 +341,18 @@ def establish(self):
from wrapyfi_ros_interfaces.srv import ROSAudioService
except ImportError:
import wrapyfi
- logging.error("[ROS] Could not import ROSAudioService. "
- "Make sure the ROS services in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS] Could not import ROSAudioService. "
+ "Make sure the ROS services in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros_interfaces_lnk.html"
+ )
sys.exit(1)
- self._client = rospy.ServiceProxy(self.in_topic, ROSAudioService, persistent=self.persistent)
+ self._client = rospy.ServiceProxy(
+ self.in_topic, ROSAudioService, persistent=self.persistent
+ )
self._req_msg = ROSAudioService._request_class()
if self.persistent:
self.established = True
@@ -304,12 +380,26 @@ def _request(self, *args, **kwargs):
:param args: tuple: Positional arguments to send in the request
:param kwargs: dict: Keyword arguments to send in the request
"""
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs],
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
args_msg = self._req_msg
args_msg.request = args_str
msg = self._client(args_msg).response
- self._queue.put((msg.chunk_size, msg.channels, msg.sample_rate, msg.encoding, msg.is_bigendian, msg.data), block=False)
+ self._queue.put(
+ (
+ msg.chunk_size,
+ msg.channels,
+ msg.sample_rate,
+ msg.encoding,
+ msg.is_bigendian,
+ msg.data,
+ ),
+ block=False,
+ )
def _await_reply(self):
"""
@@ -318,19 +408,28 @@ def _await_reply(self):
:return: Tuple[np.array, int]: The received audio chunk and rate from the ROS service
"""
try:
- chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(block=True)
+ chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(
+ block=True
+ )
if 0 < self.rate != rate:
raise ValueError("Incorrect audio rate for client")
- if encoding not in ['S16LE', 'S16BE']:
+ if encoding not in ["S16LE", "S16BE"]:
raise ValueError("Incorrect encoding for client")
- if 0 < self.chunk != chunk or self.channels != channels or len(data) != chunk * channels * 4:
+ if (
+ 0 < self.chunk != chunk
+ or self.channels != channels
+ or len(data) != chunk * channels * 4
+ ):
raise ValueError("Incorrect audio shape for client")
- aud = np.frombuffer(data, dtype=np.dtype(np.float32).newbyteorder('>' if is_bigendian else '<')).reshape(
- (chunk, channels))
+ aud = np.frombuffer(
+ data,
+ dtype=np.dtype(np.float32).newbyteorder(">" if is_bigendian else "<"),
+ ).reshape((chunk, channels))
# aud = aud / 32767.0
return aud, rate
except queue.Full:
- logging.warning(f"[ROS] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
return None, self.rate
-
diff --git a/wrapyfi/clients/ros2.py b/wrapyfi/clients/ros2.py
index 291d84e..4190adb 100755
--- a/wrapyfi/clients/ros2.py
+++ b/wrapyfi/clients/ros2.py
@@ -20,7 +20,9 @@
class ROS2Client(Client, Node):
- def __init__(self, name: str, in_topic: str, ros2_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self, name: str, in_topic: str, ros2_kwargs: Optional[dict] = None, **kwargs
+ ):
"""
Initialize the client.
@@ -32,7 +34,8 @@ def __init__(self, name: str, in_topic: str, ros2_kwargs: Optional[dict] = None,
carrier = "tcp"
if "carrier" in kwargs and kwargs["carrier"] not in ["", None]:
logging.warning(
- "[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP.")
+ "[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP."
+ )
if "carrier" in kwargs:
del kwargs["carrier"]
ROS2Middleware.activate(**ros2_kwargs or {})
@@ -53,9 +56,14 @@ def __del__(self):
@Clients.register("NativeObject", "ros2")
class ROS2NativeObjectClient(ROS2Client):
- def __init__(self, name: str, in_topic: str,
- serializer_kwargs: Optional[dict] = None,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ serializer_kwargs: Optional[dict] = None,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObject listener using the ROS 2 String message assuming the data is serialized as a JSON string.
Deserializes the data (including plugins) using the decoder and parses it to a Python object.
@@ -83,16 +91,20 @@ def establish(self):
from wrapyfi_ros2_interfaces.srv import ROS2NativeObjectService
except ImportError:
import wrapyfi
- logging.error("[ROS 2] Could not import ROS2NativeObjectService. "
- "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros2_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS 2] Could not import ROS2NativeObjectService. "
+ "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros2_interfaces_lnk.html"
+ )
sys.exit(1)
self._client = self.create_client(ROS2NativeObjectService, self.in_topic)
self._req_msg = ROS2NativeObjectService.Request()
while not self._client.wait_for_service(timeout_sec=1.0):
- logging.info('[ROS 2] Service not available, waiting again...')
+ logging.info("[ROS 2] Service not available, waiting again...")
self.established = True
def request(self, *args, **kwargs):
@@ -119,8 +131,12 @@ def _request(self, *args, **kwargs):
:param kwargs: dict: Keyword arguments to send in the request
"""
# transmit args to server
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs],
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
self._req_msg.request = args_str
future = self._client.call_async(self._req_msg)
# receive message from server
@@ -129,7 +145,11 @@ def _request(self, *args, **kwargs):
if future.done():
try:
msg = future.result()
- obj = json.loads(msg.response, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ obj = json.loads(
+ msg.response,
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
self._queue.put(obj, block=False)
except Exception as e:
logging.error("[ROS 2] Service call failed: %s" % e)
@@ -145,15 +165,27 @@ def _await_reply(self) -> Any:
reply = self._queue.get(block=True)
return reply
except queue.Full:
- logging.warning(f"[ROS 2] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS 2] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
return None
@Clients.register("Image", "ros2")
class ROS2ImageClient(ROS2Client):
- def __init__(self, name: str, in_topic: str, width: int = -1, height: int = -1,
- rgb: bool = True, fp: bool = False, jpg: bool = False, serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The Image client using the ROS 2 Image message parsed to a numpy array.
@@ -177,10 +209,10 @@ def __init__(self, name: str, in_topic: str, width: int = -1, height: int = -1,
self.jpg = jpg
if self.fp:
- self._encoding = '32FC3' if self.rgb else '32FC1'
+ self._encoding = "32FC3" if self.rgb else "32FC1"
self._type = np.float32
else:
- self._encoding = 'bgr8' if self.rgb else 'mono8'
+ self._encoding = "bgr8" if self.rgb else "mono8"
self._type = np.uint8
self._pixel_bytes = (3 if self.rgb else 1) * np.dtype(self._type).itemsize
@@ -196,13 +228,20 @@ def establish(self):
Establish the client's connection to the ROS 2 service.
"""
try:
- from wrapyfi_ros2_interfaces.srv import ROS2ImageService, ROS2CompressedImageService
+ from wrapyfi_ros2_interfaces.srv import (
+ ROS2ImageService,
+ ROS2CompressedImageService,
+ )
except ImportError:
import wrapyfi
- logging.error("[ROS 2] Could not import ROS2ImageService. "
- "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros2_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS 2] Could not import ROS2ImageService. "
+ "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros2_interfaces_lnk.html"
+ )
sys.exit(1)
if self.jpg:
self._client = self.create_client(ROS2CompressedImageService, self.in_topic)
@@ -212,7 +251,7 @@ def establish(self):
self._req_msg = ROS2ImageService.Request()
while not self._client.wait_for_service(timeout_sec=1.0):
- logging.info('[ROS 2] Service not available, waiting again...')
+ logging.info("[ROS 2] Service not available, waiting again...")
self.established = True
def request(self, *args, **kwargs):
@@ -239,8 +278,12 @@ def _request(self, *args, **kwargs):
:param kwargs: dict: Keyword arguments to send in the request
"""
# transmit args to server
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs],
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
self._req_msg.request = args_str
future = self._client.call_async(self._req_msg)
# receive message from server
@@ -253,8 +296,16 @@ def _request(self, *args, **kwargs):
if self.jpg:
self._queue.put((data.format, data.data), block=False)
else:
- self._queue.put((data.height, data.width, data.encoding, data.is_bigendian, data.data),
- block=False)
+ self._queue.put(
+ (
+ data.height,
+ data.width,
+ data.encoding,
+ data.is_bigendian,
+ data.data,
+ ),
+ block=False,
+ )
except Exception as e:
logging.error("[ROS 2] Service call failed: %s" % e)
break
@@ -268,33 +319,55 @@ def _await_reply(self):
try:
if self.jpg:
format, data = self._queue.get(block=True)
- if format != 'jpeg':
+ if format != "jpeg":
raise ValueError(f"Unsupported image format: {format}")
if self.rgb:
img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
else:
- img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE)
+ img = cv2.imdecode(
+ np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE
+ )
else:
- height, width, encoding, is_bigendian, data = self._queue.get(block=True)
+ height, width, encoding, is_bigendian, data = self._queue.get(
+ block=True
+ )
if encoding != self._encoding:
raise ValueError("Incorrect encoding for listener")
- if 0 < self.width != width or 0 < self.height != height or len(data) != height * width * self._pixel_bytes:
+ if (
+ 0 < self.width != width
+ or 0 < self.height != height
+ or len(data) != height * width * self._pixel_bytes
+ ):
raise ValueError("Incorrect image shape for listener")
- img = np.frombuffer(data, dtype=np.dtype(self._type).newbyteorder('>' if is_bigendian else '<')).reshape((height, width, -1))
+ img = np.frombuffer(
+ data,
+ dtype=np.dtype(self._type).newbyteorder(
+ ">" if is_bigendian else "<"
+ ),
+ ).reshape((height, width, -1))
if img.shape[2] == 1:
img = img.squeeze(axis=2)
return img
except queue.Full:
- logging.warning(f"[ROS 2] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS 2] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
return None
@Clients.register("AudioChunk", "ros2")
class ROS2AudioChunkClient(ROS2Client):
- def __init__(self, name: str, in_topic: str,
- channels: int = 1, rate: int = 44100, chunk: int = -1,
- serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The AudioChunk client using the ROS 2 Audio message parsed to a numpy array.
@@ -328,16 +401,20 @@ def establish(self):
from wrapyfi_ros2_interfaces.srv import ROS2AudioService
except ImportError:
import wrapyfi
- logging.error("[ROS 2] Could not import ROS2AudioService. "
- "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros2_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS 2] Could not import ROS2AudioService. "
+ "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros2_interfaces_lnk.html"
+ )
sys.exit(1)
self._client = self.create_client(ROS2AudioService, self.in_topic)
self._req_msg = ROS2AudioService.Request()
while not self._client.wait_for_service(timeout_sec=1.0):
- logging.info('[ROS 2] Service not available, waiting again...')
+ logging.info("[ROS 2] Service not available, waiting again...")
self.established = True
def request(self, *args, **kwargs):
@@ -363,8 +440,12 @@ def _request(self, *args, **kwargs):
:param args: tuple: Positional arguments to send in the request
:param kwargs: dict: Keyword arguments to send in the request
"""
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs],
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
self._req_msg.request = args_str
future = self._client.call_async(self._req_msg)
@@ -374,8 +455,17 @@ def _request(self, *args, **kwargs):
try:
msg = future.result()
data = msg.response
- self._queue.put((data.chunk_size, data.channels, data.sample_rate, data.encoding, data.is_bigendian, data.data),
- block=False)
+ self._queue.put(
+ (
+ data.chunk_size,
+ data.channels,
+ data.sample_rate,
+ data.encoding,
+ data.is_bigendian,
+ data.data,
+ ),
+ block=False,
+ )
except Exception as e:
logging.error("[ROS 2] Service call failed: %s" % e)
break
@@ -387,19 +477,28 @@ def _await_reply(self):
:return: Tuple[np.ndarray, int]: The received message as a numpy array formatted as (np.ndarray[audio_chunk, channels], int[samplerate])
"""
try:
- chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(block=False)
+ chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(
+ block=False
+ )
if 0 < self.rate != rate:
raise ValueError("Incorrect audio rate for publisher")
- if encoding not in ['S16LE', 'S16BE']:
+ if encoding not in ["S16LE", "S16BE"]:
raise ValueError("Incorrect encoding for listener")
- if 0 < self.chunk != chunk or self.channels != channels or len(data) != chunk * channels * 4:
+ if (
+ 0 < self.chunk != chunk
+ or self.channels != channels
+ or len(data) != chunk * channels * 4
+ ):
raise ValueError("Incorrect audio shape for listener")
- aud = np.frombuffer(data, dtype=np.dtype(np.float32).newbyteorder('>' if is_bigendian else '<')).reshape(
- (chunk, channels))
+ aud = np.frombuffer(
+ data,
+ dtype=np.dtype(np.float32).newbyteorder(">" if is_bigendian else "<"),
+ ).reshape((chunk, channels))
# aud = aud / 32767.0
return aud, rate
except queue.Full:
- logging.warning(f"[ROS 2] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS 2] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
return None
-
diff --git a/wrapyfi/clients/yarp.py b/wrapyfi/clients/yarp.py
index 7b53fee..4ffe3a2 100755
--- a/wrapyfi/clients/yarp.py
+++ b/wrapyfi/clients/yarp.py
@@ -14,8 +14,15 @@
class YarpClient(Client):
- def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp",
- persistent: bool = True, yarp_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ persistent: bool = True,
+ yarp_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the client.
@@ -30,6 +37,7 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca
YarpMiddleware.activate(**yarp_kwargs or {})
self.persistent = persistent
+
def close(self):
"""
Close the client.
@@ -44,9 +52,16 @@ def __del__(self):
@Clients.register("NativeObject", "yarp")
class YarpNativeObjectClient(YarpClient):
- def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp",
- persistent: bool = True,
- serializer_kwargs: Optional[dict] = None, deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ persistent: bool = True,
+ serializer_kwargs: Optional[dict] = None,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObject listener using the YARP Bottle construct assuming the data is serialized as a JSON string.
Deserializes the data (including plugins) using the decoder and parses it to a Python object.
@@ -61,7 +76,9 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca
:param deserializer_kwargs: dict: Additional kwargs for the deserializer
:param kwargs: dict: Additional kwargs for the client
"""
- super().__init__(name, in_topic, carrier=carrier, persistent=persistent, **kwargs)
+ super().__init__(
+ name, in_topic, carrier=carrier, persistent=persistent, **kwargs
+ )
self._port = None
self._queue = queue.Queue(maxsize=1)
@@ -108,8 +125,12 @@ def _request(self, *args, **kwargs):
:param args: tuple: Positional arguments to send in the request
:param kwargs: dict: Keyword arguments to send in the request
"""
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs],
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
args_msg = yarp.Bottle()
args_msg.clear()
args_msg.addString(args_str)
@@ -118,7 +139,11 @@ def _request(self, *args, **kwargs):
msg.clear()
self._port.write(args_msg, msg)
- obj = json.loads(msg.get(0).asString(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ obj = json.loads(
+ msg.get(0).asString(),
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
self._queue.put(obj, block=False)
def _await_reply(self):
@@ -131,16 +156,28 @@ def _await_reply(self):
reply = self._queue.get(block=True)
return reply
except queue.Full:
- logging.warning(f"[YARP] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[YARP] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
return None
@Clients.register("Image", "yarp")
class YarpImageClient(YarpNativeObjectClient):
- def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp",
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False,
- persistent: bool = True, serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ persistent: bool = True,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The Image client using the YARP Bottle construct parsed to a numpy array.
@@ -155,9 +192,18 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca
:param serializer_kwargs: dict: Additional kwargs for the serializer
"""
if "jpg" in kwargs:
- logging.warning("[YARP] YARP currently does not support JPG encoding in REQ/REP. Using raw image.")
+ logging.warning(
+ "[YARP] YARP currently does not support JPG encoding in REQ/REP. Using raw image."
+ )
kwargs.pop("jpg")
- super().__init__(name, in_topic, carrier=carrier, persistent=persistent, serializer_kwargs=serializer_kwargs, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ persistent=persistent,
+ serializer_kwargs=serializer_kwargs,
+ **kwargs,
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -170,15 +216,23 @@ def _request(self, *args, **kwargs):
:param args: tuple: Positional arguments to send in the request
:param kwargs: dict: Keyword arguments to send in the request
"""
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs],
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
args_msg = yarp.Bottle()
args_msg.clear()
args_msg.addString(args_str)
msg = yarp.Bottle()
msg.clear()
self._port.write(args_msg, msg)
- img = json.loads(msg.get(0).asString(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ img = json.loads(
+ msg.get(0).asString(),
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
height, width, channels = img.shape
if 0 < self.width != width or 0 < self.height != height:
raise ValueError("Incorrect image shape for client")
@@ -188,9 +242,18 @@ def _request(self, *args, **kwargs):
@Clients.register("AudioChunk", "yarp")
class YarpAudioChunkClient(YarpNativeObjectClient):
- def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp",
- channels: int = 1, rate: int = 44100, chunk: int = -1,
- persistent: bool = True, serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ persistent: bool = True,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The AudioChunk client using the YARP Bottle construct parsed to a numpy array.
@@ -203,7 +266,14 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca
:param persistent: bool: Whether to keep the service connection alive across multiple service calls. Default is True
:param serializer_kwargs: dict: Additional kwargs for the serializer
"""
- super().__init__(name, in_topic, carrier=carrier, persistent=persistent, serializer_kwargs=serializer_kwargs, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ persistent=persistent,
+ serializer_kwargs=serializer_kwargs,
+ **kwargs,
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -215,18 +285,30 @@ def _request(self, *args, **kwargs):
:param args: tuple: Positional arguments to send in the request
:param kwargs: dict: Keyword arguments to send in the request
"""
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs],
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
args_msg = yarp.Bottle()
args_msg.clear()
args_msg.addString(args_str)
msg = yarp.Bottle()
msg.clear()
self._port.write(args_msg, msg)
- chunk, channels, rate, aud = json.loads(msg.get(0).asString(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ chunk, channels, rate, aud = json.loads(
+ msg.get(0).asString(),
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
if 0 < self.rate != rate:
raise ValueError("Incorrect audio rate for client")
- if 0 < self.chunk != chunk or self.channels != channels or aud.size != chunk * channels:
+ if (
+ 0 < self.chunk != chunk
+ or self.channels != channels
+ or aud.size != chunk * channels
+ ):
raise ValueError("Incorrect audio shape for client")
else:
- self._queue.put((aud, rate), block=False)
\ No newline at end of file
+ self._queue.put((aud, rate), block=False)
diff --git a/wrapyfi/clients/zeromq.py b/wrapyfi/clients/zeromq.py
index 8c6f692..5d93839 100644
--- a/wrapyfi/clients/zeromq.py
+++ b/wrapyfi/clients/zeromq.py
@@ -18,9 +18,16 @@
class ZeroMQClient(Client):
- def __init__(self, name, in_topic, carrier="tcp",
- socket_ip: str = SOCKET_IP, socket_rep_port: int = SOCKET_REP_PORT,
- zeromq_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name,
+ in_topic,
+ carrier="tcp",
+ socket_ip: str = SOCKET_IP,
+ socket_rep_port: int = SOCKET_REP_PORT,
+ zeromq_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the client.
@@ -31,10 +38,14 @@ def __init__(self, name, in_topic, carrier="tcp",
:param kwargs: dict: Additional kwargs for the client
"""
if in_topic != "":
- logging.warning(f"[ZeroMQ] ZeroMQ does not support topics for the REQ/REP pattern. Topic {in_topic} removed")
+ logging.warning(
+ f"[ZeroMQ] ZeroMQ does not support topics for the REQ/REP pattern. Topic {in_topic} removed"
+ )
in_topic = ""
if carrier or carrier != "tcp":
- logging.warning("[ZeroMQ] ZeroMQ does not support other carriers than TCP for REQ/REP pattern. Using TCP.")
+ logging.warning(
+ "[ZeroMQ] ZeroMQ does not support other carriers than TCP for REQ/REP pattern. Using TCP."
+ )
carrier = "tcp"
super().__init__(name, in_topic, carrier=carrier, **kwargs)
@@ -56,9 +67,15 @@ def __del__(self):
@Clients.register("NativeObject", "zeromq")
class ZeroMQNativeObjectClient(ZeroMQClient):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp",
- serializer_kwargs: Optional[dict] = None,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ serializer_kwargs: Optional[dict] = None,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific client for handling native Python objects, serializing them to JSON strings for transmission.
@@ -86,9 +103,13 @@ def establish(self, **kwargs):
self._socket = zmq.Context().instance().socket(zmq.REQ)
for socket_property in ZeroMQMiddlewareReqRep().zeromq_kwargs.items():
if isinstance(socket_property[1], str):
- self._socket.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1])
+ self._socket.setsockopt_string(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
else:
- self._socket.setsockopt(getattr(zmq, socket_property[0]), socket_property[1])
+ self._socket.setsockopt(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
self._socket.connect(self.socket_address)
self.established = True
@@ -115,11 +136,15 @@ def _request(self, *args, **kwargs):
:param args: tuple: Arguments to be serialized and sent
:param kwargs: dict: Keyword arguments to be serialized and sent
"""
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs
+ )
self._socket.send_string(args_str)
obj_str = self._socket.recv_string()
- obj = json.loads(obj_str, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ obj = json.loads(
+ obj_str, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs
+ )
self._queue.put(obj, block=False)
def _await_reply(self):
@@ -132,16 +157,28 @@ def _await_reply(self):
reply = self._queue.get(block=True)
return reply
except queue.Empty:
- logging.warning(f"[ZeroMQ] Discarding data because queue is empty. "
- f"This happened due to bad synchronization in {self.__class__.__name__}")
+ logging.warning(
+ f"[ZeroMQ] Discarding data because queue is empty. "
+ f"This happened due to bad synchronization in {self.__class__.__name__}"
+ )
return None
@Clients.register("Image", "zeromq")
class ZeroMQImageClient(ZeroMQNativeObjectClient):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp",
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False,
- serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The Image client using the ZeroMQ message construct parsed to a numpy array.
@@ -155,7 +192,13 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp",
:param jpg: bool: True if the image should be compressed to JPG before sending. Default is False
:param serializer_kwargs: dict: Additional kwargs for the serializer
"""
- super().__init__(name, in_topic, carrier=carrier, serializer_kwargs=serializer_kwargs, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ serializer_kwargs=serializer_kwargs,
+ **kwargs,
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -171,12 +214,16 @@ def _request(self, *args, **kwargs):
:param args: tuple: Arguments to be serialized and sent
:param kwargs: dict: Keyword arguments to be serialized and sent
"""
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs
+ )
self._socket.send_string(args_str)
if self.jpg:
reply_bytes = self._socket.recv()
- reply_img = cv2.imdecode(np.frombuffer(reply_bytes, np.uint8), cv2.IMREAD_ANYCOLOR)
+ reply_img = cv2.imdecode(
+ np.frombuffer(reply_bytes, np.uint8), cv2.IMREAD_ANYCOLOR
+ )
else:
reply_str = self._socket.recv_string()
reply_img_list = json.loads(reply_str)
@@ -192,20 +239,34 @@ def _await_reply(self):
try:
img = self._queue.get(block=True)
height, width, channels = img.shape
- if 0 < self.width != width or 0 < self.height != height or img.size != height * width * (3 if self.rgb else 1):
+ if (
+ 0 < self.width != width
+ or 0 < self.height != height
+ or img.size != height * width * (3 if self.rgb else 1)
+ ):
raise ValueError("Incorrect image shape for subscriber")
return img
except queue.Empty:
- logging.warning(f"[ZeroMQ] Discarding data because queue is empty. "
- f"This happened due to bad synchronization in {self.__class__.__name__}")
+ logging.warning(
+ f"[ZeroMQ] Discarding data because queue is empty. "
+ f"This happened due to bad synchronization in {self.__class__.__name__}"
+ )
return None
@Clients.register("AudioChunk", "zeromq")
class ZeroMQAudioChunkClient(ZeroMQNativeObjectClient):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp",
- channels: int = 1, rate: int = 44100, chunk: int = -1,
- serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The AudioChunk client using the ZeroMQ message construct parsed to a numpy array.
@@ -219,7 +280,13 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp",
:param jpg: bool: True if the image should be compressed to JPG before sending. Default is False
:param serializer_kwargs: dict: Additional kwargs for the serializer
"""
- super().__init__(name, in_topic, carrier=carrier, serializer_kwargs=serializer_kwargs, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ serializer_kwargs=serializer_kwargs,
+ **kwargs,
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -231,7 +298,9 @@ def _request(self, *args, **kwargs):
:param args: tuple: Arguments to be serialized and sent
:param kwargs: dict: Keyword arguments to be serialized and sent
"""
- args_str = json.dumps([args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs)
+ args_str = json.dumps(
+ [args, kwargs], cls=self._plugin_encoder, **self._serializer_kwargs
+ )
self._socket.send_string(args_str)
reply_str = self._socket.recv_string()
@@ -250,10 +319,16 @@ def _await_reply(self):
chunk, channels, rate, aud = self._queue.get(block=True)
if 0 < self.rate != rate:
raise ValueError("Incorrect audio rate for listener")
- if 0 < self.chunk != chunk or self.channels != channels or aud.size != chunk * channels:
+ if (
+ 0 < self.chunk != chunk
+ or self.channels != channels
+ or aud.size != chunk * channels
+ ):
raise ValueError("Incorrect audio shape for listener")
return aud, rate
except queue.Empty:
- logging.warning(f"[ZeroMQ] Discarding data because queue is empty. "
- f"This happened due to bad synchronization in {self.__class__.__name__}")
- return None
\ No newline at end of file
+ logging.warning(
+ f"[ZeroMQ] Discarding data because queue is empty. "
+ f"This happened due to bad synchronization in {self.__class__.__name__}"
+ )
+ return None
diff --git a/wrapyfi/config/manager.py b/wrapyfi/config/manager.py
index 36df4fa..2943fee 100755
--- a/wrapyfi/config/manager.py
+++ b/wrapyfi/config/manager.py
@@ -9,6 +9,7 @@ class ConfigManager(metaclass=SingletonOptimized):
"""
The configuration manager is a singleton which is invoked once throughout the runtime.
"""
+
def __init__(self, config: Optional[Union[dict, str]], **kwargs):
"""
Initializing the ConfigManager. The configuration can be provided as a yaml file name or as a dictionary.
@@ -44,4 +45,3 @@ def __writefile(self, filename: str):
"""
with open(filename, "w") as fp:
yaml.safe_dump(self.config, fp)
-
diff --git a/wrapyfi/connect/__init__.py b/wrapyfi/connect/__init__.py
index fe067bf..20f5431 100755
--- a/wrapyfi/connect/__init__.py
+++ b/wrapyfi/connect/__init__.py
@@ -6,4 +6,4 @@
Listeners.scan()
Publishers.scan()
Servers.scan()
-Clients.scan()
\ No newline at end of file
+Clients.scan()
diff --git a/wrapyfi/connect/clients.py b/wrapyfi/connect/clients.py
index 7a9a6c4..45666f2 100755
--- a/wrapyfi/connect/clients.py
+++ b/wrapyfi/connect/clients.py
@@ -9,6 +9,7 @@ class Clients(object):
"""
A class that holds all clients and their corresponding middleware communicators.
"""
+
registry = {}
mwares = set()
@@ -20,10 +21,12 @@ def register(cls, data_type: str, communicator: str):
:param data_type: str: The data type to register the client for e.g., "NativeObject", "Image", "AudioChunk", etc.
:param communicator: str: The middleware communicator to register the client for e.g., "ros", "ros2", "yarp", "zeromq", etc.
"""
+
def decorator(cls_):
cls.registry[data_type + ":" + communicator] = cls_
cls.mwares.add(communicator)
return cls_
+
return decorator
@staticmethod
@@ -31,9 +34,15 @@ def scan():
"""
Scan for clients and add them to the registry.
"""
- modules = glob(os.path.join(os.path.dirname(__file__), "..", "clients", "*.py"), recursive=True)
- modules = ["wrapyfi.clients." + module.replace(os.path.dirname(__file__) + "/../clients/", "") for module in
- modules]
+ modules = glob(
+ os.path.join(os.path.dirname(__file__), "..", "clients", "*.py"),
+ recursive=True,
+ )
+ modules = [
+ "wrapyfi.clients."
+ + module.replace(os.path.dirname(__file__) + "/../clients/", "")
+ for module in modules
+ ]
dynamic_module_import(modules, globals())
@@ -41,6 +50,7 @@ class Client(object):
"""
A base class for clients.
"""
+
def __init__(self, name: str, in_topic: str, carrier: str = "", **kwargs):
"""
Initialize the client.
diff --git a/wrapyfi/connect/listeners.py b/wrapyfi/connect/listeners.py
index fe4c595..54fcdfe 100755
--- a/wrapyfi/connect/listeners.py
+++ b/wrapyfi/connect/listeners.py
@@ -9,6 +9,7 @@ class ListenerWatchDog(metaclass=SingletonOptimized):
"""
A watchdog that scans for listeners and removes them from the ring if they are not established.
"""
+
def __init__(self, repeats: int = 10, inner_repeats: int = 10):
"""
Initialize the ListenerWatchDog.
@@ -53,6 +54,7 @@ class Listeners(object):
"""
A class that holds all listeners and their corresponding middleware communicators.
"""
+
registry = {}
mwares = set()
@@ -64,10 +66,12 @@ def register(cls, data_type: str, communicator: str):
:param data_type: str: The data type to register the listener for e.g., "NativeObject", "Image", "AudioChunk", etc.
:param communicator: str: The middleware communicator to register the listener for e.g., "ros", "ros2", "yarp", "zeromq", etc.
"""
+
def decorator(cls_):
cls.registry[data_type + ":" + communicator] = cls_
cls.mwares.add(communicator)
return cls_
+
return decorator
@staticmethod
@@ -75,9 +79,15 @@ def scan():
"""
Scan for listeners and add them to the registry.
"""
- modules = glob(os.path.join(os.path.dirname(__file__), "..", "listeners", "*.py"), recursive=True)
- modules = ["wrapyfi.listeners." + module.replace(os.path.dirname(__file__) + "/../listeners/", "") for module in
- modules]
+ modules = glob(
+ os.path.join(os.path.dirname(__file__), "..", "listeners", "*.py"),
+ recursive=True,
+ )
+ modules = [
+ "wrapyfi.listeners."
+ + module.replace(os.path.dirname(__file__) + "/../listeners/", "")
+ for module in modules
+ ]
dynamic_module_import(modules, globals())
@@ -85,7 +95,15 @@ class Listener(object):
"""
A base class for listeners.
"""
- def __init__(self, name: str, in_topic: str, carrier: str = "", should_wait: bool = True, **kwargs):
+
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "",
+ should_wait: bool = True,
+ **kwargs,
+ ):
"""
Initialize the Listener.
diff --git a/wrapyfi/connect/publishers.py b/wrapyfi/connect/publishers.py
index edd9b6b..de92df1 100755
--- a/wrapyfi/connect/publishers.py
+++ b/wrapyfi/connect/publishers.py
@@ -9,6 +9,7 @@ class PublisherWatchDog(metaclass=SingletonOptimized):
"""
A watchdog that scans for publishers and removes them from the ring if they are not established.
"""
+
def __init__(self, repeats: int = 10, inner_repeats: int = 10):
"""
Initialize the PublisherWatchDog.
@@ -53,6 +54,7 @@ class Publishers(object):
"""
A class that holds all publishers and their corresponding middleware communicators.
"""
+
registry = {}
mwares = set()
@@ -65,10 +67,12 @@ def register(cls, data_type: str, communicator: str):
:param communicator: str: The middleware communicator to register the publisher for e.g., "ros", "ros2", "yarp", "zeromq", etc.
:return: Callable[..., Any]: A decorator function that registers the decorated class as a publisher for the given data type and middleware communicator
"""
+
def decorator(cls_):
cls.registry[data_type + ":" + communicator] = cls_
cls.mwares.add(communicator)
return cls_
+
return decorator
@staticmethod
@@ -76,9 +80,15 @@ def scan():
"""
Scan for publishers and add them to the registry.
"""
- modules = glob(os.path.join(os.path.dirname(__file__), "..", "publishers", "*.py"), recursive=True)
- modules = ["wrapyfi.publishers." + module.replace(os.path.dirname(__file__) + "/../publishers/", "") for module in
- modules]
+ modules = glob(
+ os.path.join(os.path.dirname(__file__), "..", "publishers", "*.py"),
+ recursive=True,
+ )
+ modules = [
+ "wrapyfi.publishers."
+ + module.replace(os.path.dirname(__file__) + "/../publishers/", "")
+ for module in modules
+ ]
dynamic_module_import(modules, globals())
@@ -86,7 +96,15 @@ class Publisher(object):
"""
A base class for all publishers.
"""
- def __init__(self, name: str, out_topic: str, carrier: str = "", should_wait: bool = True, **kwargs):
+
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "",
+ should_wait: bool = True,
+ **kwargs,
+ ):
"""
Initialize the Publisher.
diff --git a/wrapyfi/connect/servers.py b/wrapyfi/connect/servers.py
index 4f25171..831a924 100755
--- a/wrapyfi/connect/servers.py
+++ b/wrapyfi/connect/servers.py
@@ -10,6 +10,7 @@ class Servers(object):
"""
A class that holds all servers and their corresponding middleware communicators.
"""
+
registry = {}
mwares = set()
@@ -22,10 +23,12 @@ def register(cls, data_type: str, communicator: str):
:param communicator: str: The middleware communicator to register the server for e.g., "ros", "ros2", "yarp", "zeromq", etc.
:return: Callable: A decorator that registers the server with the given data type and middleware communicator
"""
+
def decorator(cls_):
cls.registry[data_type + ":" + communicator] = cls_
cls.mwares.add(communicator)
return cls_
+
return decorator
@staticmethod
@@ -33,9 +36,15 @@ def scan():
"""
Scan for servers and add them to the registry.
"""
- modules = glob(os.path.join(os.path.dirname(__file__), "..", "servers", "*.py"), recursive=True)
- modules = ["wrapyfi.servers." + module.replace(os.path.dirname(__file__) + "/../servers/", "") for module in
- modules]
+ modules = glob(
+ os.path.join(os.path.dirname(__file__), "..", "servers", "*.py"),
+ recursive=True,
+ )
+ modules = [
+ "wrapyfi.servers."
+ + module.replace(os.path.dirname(__file__) + "/../servers/", "")
+ for module in modules
+ ]
dynamic_module_import(modules, globals())
@@ -43,7 +52,15 @@ class Server(object):
"""
A base class for servers.
"""
- def __init__(self, name: str, out_topic: str, carrier: str = "", out_topic_connect: Optional[str] = None, **kwargs):
+
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "",
+ out_topic_connect: Optional[str] = None,
+ **kwargs,
+ ):
"""
Initialize the server.
@@ -55,7 +72,9 @@ def __init__(self, name: str, out_topic: str, carrier: str = "", out_topic_conne
self.__name__ = name
self.out_topic = out_topic
self.carrier = carrier
- self.out_topic_connect = out_topic + ":out" if out_topic_connect is None else out_topic_connect
+ self.out_topic_connect = (
+ out_topic + ":out" if out_topic_connect is None else out_topic_connect
+ )
self.established = False
def establish(self):
diff --git a/wrapyfi/connect/wrapper.py b/wrapyfi/connect/wrapper.py
index 0122bdd..6403ef2 100755
--- a/wrapyfi/connect/wrapper.py
+++ b/wrapyfi/connect/wrapper.py
@@ -29,7 +29,9 @@ def __init__(self):
self.activate_communication(getattr(self.__class__, key), mode=value)
@classmethod
- def __trigger_publish(cls, func: Callable[..., Any], instance_id: str, kwd: dict, *wds, **kwds):
+ def __trigger_publish(
+ cls, func: Callable[..., Any], instance_id: str, kwd: dict, *wds, **kwds
+ ):
"""
Triggers the publish mode of the middleware communicator.
@@ -42,60 +44,97 @@ def __trigger_publish(cls, func: Callable[..., Any], instance_id: str, kwd: dict
:raises: KeyError: If the intended publisher type and middleware are unavailable, resorting to a fallback publisher
"""
- if "wrapped_executor" not in \
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][0]:
+ if (
+ "wrapped_executor"
+ not in cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"][0]
+ ):
# instantiate the publishers
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"].reverse()
- for communicator in cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][
- "communicator"]:
+ cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][
+ "communicator"
+ ].reverse()
+ for communicator in cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"]:
# single element
if isinstance(communicator["return_func_type"], str):
- return_func_pub_kwargs = deepcopy(communicator["return_func_kwargs"])
- return_func_pub_kwargs.update(return_func_pub_kwargs.get("publisher_kwargs", {}))
+ return_func_pub_kwargs = deepcopy(
+ communicator["return_func_kwargs"]
+ )
+ return_func_pub_kwargs.update(
+ return_func_pub_kwargs.get("publisher_kwargs", {})
+ )
return_func_pub_kwargs.pop("listener_kwargs", None)
return_func_pub_kwargs.pop("publisher_kwargs", None)
new_args, new_kwargs = match_args(
- communicator["return_func_args"], return_func_pub_kwargs, wds[1:], kwd)
+ communicator["return_func_args"],
+ return_func_pub_kwargs,
+ wds[1:],
+ kwd,
+ )
return_func_type = communicator["return_func_type"]
- return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR)
+ return_func_middleware = new_kwargs.pop(
+ "middleware", DEFAULT_COMMUNICATOR
+ )
try:
communicator["wrapped_executor"] = pub.Publishers.registry[
- return_func_type + return_func_middleware](*new_args, **new_kwargs)
+ return_func_type + return_func_middleware
+ ](*new_args, **new_kwargs)
except KeyError:
communicator["wrapped_executor"] = pub.Publishers.registry[
- "MMO:fallback"](*new_args,
- missing_middleware_object=return_func_type + return_func_middleware,
- **new_kwargs)
+ "MMO:fallback"
+ ](
+ *new_args,
+ missing_middleware_object=return_func_type
+ + return_func_middleware,
+ **new_kwargs,
+ )
communicator["return_func_type"] = "MMO:"
# list for single return
elif isinstance(communicator["return_func_type"], list):
communicator["wrapped_executor"] = []
for comm_idx in range(len(communicator["return_func_type"])):
- return_func_pub_kwargs = deepcopy(communicator["return_func_kwargs"][comm_idx])
- return_func_pub_kwargs.update(return_func_pub_kwargs.get("publisher_kwargs", {}))
+ return_func_pub_kwargs = deepcopy(
+ communicator["return_func_kwargs"][comm_idx]
+ )
+ return_func_pub_kwargs.update(
+ return_func_pub_kwargs.get("publisher_kwargs", {})
+ )
return_func_pub_kwargs.pop("listener_kwargs", None)
return_func_pub_kwargs.pop("publisher_kwargs", None)
new_args, new_kwargs = match_args(
- communicator["return_func_args"][comm_idx], return_func_pub_kwargs, wds[1:], kwd)
+ communicator["return_func_args"][comm_idx],
+ return_func_pub_kwargs,
+ wds[1:],
+ kwd,
+ )
return_func_type = communicator["return_func_type"][comm_idx]
- return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR)
+ return_func_middleware = new_kwargs.pop(
+ "middleware", DEFAULT_COMMUNICATOR
+ )
try:
communicator["wrapped_executor"].append(
- pub.Publishers.registry[return_func_type + return_func_middleware](*new_args,
- **new_kwargs))
+ pub.Publishers.registry[
+ return_func_type + return_func_middleware
+ ](*new_args, **new_kwargs)
+ )
except KeyError:
communicator["wrapped_executor"].append(
- pub.Publishers.registry[
- "MMO:fallback"](*new_args,
- missing_middleware_object=return_func_type + return_func_middleware,
- **new_kwargs))
+ pub.Publishers.registry["MMO:fallback"](
+ *new_args,
+ missing_middleware_object=return_func_type
+ + return_func_middleware,
+ **new_kwargs,
+ )
+ )
communicator["return_func_type"][comm_idx] = "MMO:"
returns = func(*wds, **kwds)
for ret_idx, ret in enumerate(returns):
- wrp_exec = \
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][ret_idx][
- "wrapped_executor"]
+ wrp_exec = cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"][ret_idx]["wrapped_executor"]
# single element
if isinstance(wrp_exec, pub.Publisher):
wrp_exec.publish(ret)
@@ -106,7 +145,9 @@ def __trigger_publish(cls, func: Callable[..., Any], instance_id: str, kwd: dict
return returns
@classmethod
- def __trigger_listen(cls, func: Callable[..., Any], instance_id: str, kwd: dict, *wds, **kwds):
+ def __trigger_listen(
+ cls, func: Callable[..., Any], instance_id: str, kwd: dict, *wds, **kwds
+ ):
"""
Triggers the listen mode of the middleware communicator.
@@ -119,59 +160,104 @@ def __trigger_listen(cls, func: Callable[..., Any], instance_id: str, kwd: dict,
:raises: KeyError: If the intended listener type and middleware are unavailable, resorting to a fallback listener
"""
- if "wrapped_executor" not in \
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][0]:
+ if (
+ "wrapped_executor"
+ not in cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"][0]
+ ):
# instantiate the listeners
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"].reverse()
- for communicator in cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"]:
+ cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][
+ "communicator"
+ ].reverse()
+ for communicator in cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"]:
# single element
if isinstance(communicator["return_func_type"], str):
- return_func_lsn_kwargs = deepcopy(communicator["return_func_kwargs"])
- return_func_lsn_kwargs.update(return_func_lsn_kwargs.get("listener_kwargs", {}))
+ return_func_lsn_kwargs = deepcopy(
+ communicator["return_func_kwargs"]
+ )
+ return_func_lsn_kwargs.update(
+ return_func_lsn_kwargs.get("listener_kwargs", {})
+ )
return_func_lsn_kwargs.pop("listener_kwargs", None)
return_func_lsn_kwargs.pop("publisher_kwargs", None)
- new_args, new_kwargs = match_args(communicator["return_func_args"], return_func_lsn_kwargs, wds[1:],
- kwd)
+ new_args, new_kwargs = match_args(
+ communicator["return_func_args"],
+ return_func_lsn_kwargs,
+ wds[1:],
+ kwd,
+ )
return_func_type = communicator["return_func_type"]
- return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR)
+ return_func_middleware = new_kwargs.pop(
+ "middleware", DEFAULT_COMMUNICATOR
+ )
try:
communicator["wrapped_executor"] = lsn.Listeners.registry[
- return_func_type + return_func_middleware](*new_args, **new_kwargs)
+ return_func_type + return_func_middleware
+ ](*new_args, **new_kwargs)
except KeyError:
communicator["wrapped_executor"] = lsn.Listeners.registry[
- "MMO:fallback"](*new_args,
- missing_middleware_object=return_func_type + return_func_middleware,
- **new_kwargs)
+ "MMO:fallback"
+ ](
+ *new_args,
+ missing_middleware_object=return_func_type
+ + return_func_middleware,
+ **new_kwargs,
+ )
communicator["return_func_type"] = "MMO:"
# list for single return
elif isinstance(communicator["return_func_type"], list):
communicator["wrapped_executor"] = []
for comm_idx in range(len(communicator["return_func_type"])):
- return_func_lsn_kwargs = deepcopy(communicator["return_func_kwargs"][comm_idx])
- return_func_lsn_kwargs.update(return_func_lsn_kwargs.get("listener_kwargs", {}))
+ return_func_lsn_kwargs = deepcopy(
+ communicator["return_func_kwargs"][comm_idx]
+ )
+ return_func_lsn_kwargs.update(
+ return_func_lsn_kwargs.get("listener_kwargs", {})
+ )
return_func_lsn_kwargs.pop("listener_kwargs", None)
return_func_lsn_kwargs.pop("publisher_kwargs", None)
- new_args, new_kwargs = match_args(communicator["return_func_args"][comm_idx],
- return_func_lsn_kwargs, wds[1:], kwd)
+ new_args, new_kwargs = match_args(
+ communicator["return_func_args"][comm_idx],
+ return_func_lsn_kwargs,
+ wds[1:],
+ kwd,
+ )
return_func_type = communicator["return_func_type"][comm_idx]
- return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR)
+ return_func_middleware = new_kwargs.pop(
+ "middleware", DEFAULT_COMMUNICATOR
+ )
try:
communicator["wrapped_executor"].append(
- lsn.Listeners.registry[return_func_type + return_func_middleware](*new_args, **new_kwargs))
+ lsn.Listeners.registry[
+ return_func_type + return_func_middleware
+ ](*new_args, **new_kwargs)
+ )
except KeyError:
communicator["wrapped_executor"].append(
- lsn.Listeners.registry[
- "MMO:fallback"](*new_args,
- missing_middleware_object=return_func_type + return_func_middleware,
- **new_kwargs))
+ lsn.Listeners.registry["MMO:fallback"](
+ *new_args,
+ missing_middleware_object=return_func_type
+ + return_func_middleware,
+ **new_kwargs,
+ )
+ )
communicator["return_func_type"][comm_idx] = "MMO:"
returns = []
for ret_idx in range(
- len(cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"])):
- wrp_exec = cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][ret_idx][
- "wrapped_executor"]
+ len(
+ cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][
+ "communicator"
+ ]
+ )
+ ):
+ wrp_exec = cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"][ret_idx]["wrapped_executor"]
# single element
if isinstance(wrp_exec, lsn.Listener):
returns.append(wrp_exec.listen())
@@ -184,7 +270,9 @@ def __trigger_listen(cls, func: Callable[..., Any], instance_id: str, kwd: dict,
return returns
@classmethod
- def __trigger_reply(cls, func: Callable[..., Any], instance_id: str, kwd, *wds, **kwds):
+ def __trigger_reply(
+ cls, func: Callable[..., Any], instance_id: str, kwd, *wds, **kwds
+ ):
"""
Triggers the reply mode of the middleware communicator.
@@ -197,60 +285,99 @@ def __trigger_reply(cls, func: Callable[..., Any], instance_id: str, kwd, *wds,
:raises: KeyError: If the intended server type and middleware are unavailable, resorting to a fallback server
"""
- if "wrapped_executor" not in \
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][0]:
+ if (
+ "wrapped_executor"
+ not in cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"][0]
+ ):
# instantiate the publishers
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"].reverse()
- for communicator in cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"]:
+ cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][
+ "communicator"
+ ].reverse()
+ for communicator in cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"]:
# single element
if isinstance(communicator["return_func_type"], str):
- return_func_pub_kwargs = deepcopy(communicator["return_func_kwargs"])
- return_func_pub_kwargs.update(return_func_pub_kwargs.get("publisher_kwargs", {}))
+ return_func_pub_kwargs = deepcopy(
+ communicator["return_func_kwargs"]
+ )
+ return_func_pub_kwargs.update(
+ return_func_pub_kwargs.get("publisher_kwargs", {})
+ )
return_func_pub_kwargs.pop("listener_kwargs", None)
return_func_pub_kwargs.pop("publisher_kwargs", None)
new_args, new_kwargs = match_args(
- communicator["return_func_args"], return_func_pub_kwargs, wds[1:], kwd)
+ communicator["return_func_args"],
+ return_func_pub_kwargs,
+ wds[1:],
+ kwd,
+ )
return_func_type = communicator["return_func_type"]
- return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR)
+ return_func_middleware = new_kwargs.pop(
+ "middleware", DEFAULT_COMMUNICATOR
+ )
try:
communicator["wrapped_executor"] = srv.Servers.registry[
- return_func_type + return_func_middleware](*new_args, **new_kwargs)
+ return_func_type + return_func_middleware
+ ](*new_args, **new_kwargs)
except KeyError:
communicator["wrapped_executor"] = srv.Servers.registry[
- "MMO:fallback"](*new_args,
- missing_middleware_object=return_func_type + return_func_middleware,
- **new_kwargs)
+ "MMO:fallback"
+ ](
+ *new_args,
+ missing_middleware_object=return_func_type
+ + return_func_middleware,
+ **new_kwargs,
+ )
communicator["return_func_type"] = "MMO:"
# list for single return
elif isinstance(communicator["return_func_type"], list):
communicator["wrapped_executor"] = []
for comm_idx in range(len(communicator["return_func_type"])):
- return_func_pub_kwargs = deepcopy(communicator["return_func_kwargs"][comm_idx])
+ return_func_pub_kwargs = deepcopy(
+ communicator["return_func_kwargs"][comm_idx]
+ )
return_func_pub_kwargs.update(
- return_func_pub_kwargs.get("publisher_kwargs", {}))
+ return_func_pub_kwargs.get("publisher_kwargs", {})
+ )
return_func_pub_kwargs.pop("listener_kwargs", None)
return_func_pub_kwargs.pop("publisher_kwargs", None)
new_args, new_kwargs = match_args(
- communicator["return_func_args"][comm_idx], return_func_pub_kwargs, wds[1:],
- kwd)
+ communicator["return_func_args"][comm_idx],
+ return_func_pub_kwargs,
+ wds[1:],
+ kwd,
+ )
return_func_type = communicator["return_func_type"][comm_idx]
- return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR)
+ return_func_middleware = new_kwargs.pop(
+ "middleware", DEFAULT_COMMUNICATOR
+ )
try:
communicator["wrapped_executor"].append(
- srv.Servers.registry[return_func_type + return_func_middleware](
- *new_args, **new_kwargs))
+ srv.Servers.registry[
+ return_func_type + return_func_middleware
+ ](*new_args, **new_kwargs)
+ )
except KeyError:
communicator["wrapped_executor"].append(
- srv.Servers.registry[
- "MMO:fallback"](*new_args,
- missing_middleware_object=return_func_type + return_func_middleware,
- **new_kwargs))
+ srv.Servers.registry["MMO:fallback"](
+ *new_args,
+ missing_middleware_object=return_func_type
+ + return_func_middleware,
+ **new_kwargs,
+ )
+ )
communicator["return_func_type"][comm_idx] = "MMO:"
returns = None
for ret_idx, functor in enumerate(
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"]):
+ cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][
+ "communicator"
+ ]
+ ):
wrp_exec = functor["wrapped_executor"]
# single element
if isinstance(wrp_exec, srv.Server):
@@ -270,7 +397,9 @@ def __trigger_reply(cls, func: Callable[..., Any], instance_id: str, kwd, *wds,
return returns
@classmethod
- def __trigger_request(cls, func: Callable[..., Any], instance_id: str, kwd, *wds, **kwds):
+ def __trigger_request(
+ cls, func: Callable[..., Any], instance_id: str, kwd, *wds, **kwds
+ ):
"""
Triggers the request mode of the middleware communicator.
@@ -283,57 +412,99 @@ def __trigger_request(cls, func: Callable[..., Any], instance_id: str, kwd, *wds
:raises: KeyError: If the intended client type and middleware are unavailable, resorting to a fallback client
"""
- if "wrapped_executor" not in \
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"][0]:
+ if (
+ "wrapped_executor"
+ not in cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"][0]
+ ):
# instantiate the listeners
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"].reverse()
- for communicator in cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"]:
+ cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][
+ "communicator"
+ ].reverse()
+ for communicator in cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"]:
# single element
if isinstance(communicator["return_func_type"], str):
- return_func_lsn_kwargs = deepcopy(communicator["return_func_kwargs"])
- return_func_lsn_kwargs.update(return_func_lsn_kwargs.get("listener_kwargs", {}))
+ return_func_lsn_kwargs = deepcopy(
+ communicator["return_func_kwargs"]
+ )
+ return_func_lsn_kwargs.update(
+ return_func_lsn_kwargs.get("listener_kwargs", {})
+ )
return_func_lsn_kwargs.pop("listener_kwargs", None)
return_func_lsn_kwargs.pop("publisher_kwargs", None)
- new_args, new_kwargs = match_args(communicator["return_func_args"], return_func_lsn_kwargs, wds[1:],
- kwd)
+ new_args, new_kwargs = match_args(
+ communicator["return_func_args"],
+ return_func_lsn_kwargs,
+ wds[1:],
+ kwd,
+ )
return_func_type = communicator["return_func_type"]
- return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR)
+ return_func_middleware = new_kwargs.pop(
+ "middleware", DEFAULT_COMMUNICATOR
+ )
try:
communicator["wrapped_executor"] = clt.Clients.registry[
- return_func_type + return_func_middleware](*new_args, **new_kwargs)
+ return_func_type + return_func_middleware
+ ](*new_args, **new_kwargs)
except KeyError:
communicator["wrapped_executor"] = clt.Clients.registry[
- "MMO:fallback"](*new_args,
- missing_middleware_object=return_func_type + return_func_middleware,
- **new_kwargs)
+ "MMO:fallback"
+ ](
+ *new_args,
+ missing_middleware_object=return_func_type
+ + return_func_middleware,
+ **new_kwargs,
+ )
communicator["return_func_type"] = "MMO:"
# list for single return
elif isinstance(communicator["return_func_type"], list):
communicator["wrapped_executor"] = []
for comm_idx in range(len(communicator["return_func_type"])):
- return_func_lsn_kwargs = deepcopy(communicator["return_func_kwargs"][comm_idx])
- return_func_lsn_kwargs.update(return_func_lsn_kwargs.get("listener_kwargs", {}))
+ return_func_lsn_kwargs = deepcopy(
+ communicator["return_func_kwargs"][comm_idx]
+ )
+ return_func_lsn_kwargs.update(
+ return_func_lsn_kwargs.get("listener_kwargs", {})
+ )
return_func_lsn_kwargs.pop("listener_kwargs", None)
return_func_lsn_kwargs.pop("publisher_kwargs", None)
- new_args, new_kwargs = match_args(communicator["return_func_args"][comm_idx],
- return_func_lsn_kwargs, wds[1:], kwd)
+ new_args, new_kwargs = match_args(
+ communicator["return_func_args"][comm_idx],
+ return_func_lsn_kwargs,
+ wds[1:],
+ kwd,
+ )
return_func_type = communicator["return_func_type"][comm_idx]
- return_func_middleware = new_kwargs.pop("middleware", DEFAULT_COMMUNICATOR)
+ return_func_middleware = new_kwargs.pop(
+ "middleware", DEFAULT_COMMUNICATOR
+ )
try:
communicator["wrapped_executor"].append(
- clt.Clients.registry[return_func_type + return_func_middleware](*new_args, **new_kwargs))
+ clt.Clients.registry[
+ return_func_type + return_func_middleware
+ ](*new_args, **new_kwargs)
+ )
except KeyError:
communicator["wrapped_executor"].append(
- clt.Clients.registry[
- "MMO:fallback"](*new_args,
- missing_middleware_object=return_func_type + return_func_middleware,
- **new_kwargs))
+ clt.Clients.registry["MMO:fallback"](
+ *new_args,
+ missing_middleware_object=return_func_type
+ + return_func_middleware,
+ **new_kwargs,
+ )
+ )
communicator["return_func_type"][comm_idx] = "MMO:"
returns = []
for ret_idx, functor in enumerate(
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"]):
+ cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][
+ "communicator"
+ ]
+ ):
wrp_exec = functor["wrapped_executor"]
# single element
if isinstance(wrp_exec, clt.Client):
@@ -349,7 +520,13 @@ def __trigger_request(cls, func: Callable[..., Any], instance_id: str, kwd, *wds
return returns
@classmethod
- def register(cls, data_type: Union[str, List[Any]], middleware: str = DEFAULT_COMMUNICATOR, *args, **kwargs):
+ def register(
+ cls,
+ data_type: Union[str, List[Any]],
+ middleware: str = DEFAULT_COMMUNICATOR,
+ *args,
+ **kwargs,
+ ):
"""
Registers a function to the middleware communicator, defining its communication message type and associated
middleware. Note that the function returned is a wrapper that can alter the behavior of the registered function
@@ -363,6 +540,7 @@ def register(cls, data_type: Union[str, List[Any]], middleware: str = DEFAULT_CO
:raises: NotImplementedError: If `data_type` is a dictionary or an unsupported type
"""
+
def encapsulate(func):
# define the communication message type (single element)
if isinstance(data_type, str):
@@ -376,29 +554,45 @@ def encapsulate(func):
return_func_args, return_func_kwargs, return_func_type = [], [], []
for arg in data_type:
data_spec = arg[0] + ":"
- return_func_args.append([a for a in arg[2:] if not isinstance(a, dict)])
- return_func_kwargs.append(*[a for a in arg[2:] if isinstance(a, dict)])
+ return_func_args.append(
+ [a for a in arg[2:] if not isinstance(a, dict)]
+ )
+ return_func_kwargs.append(
+ *[a for a in arg[2:] if isinstance(a, dict)]
+ )
return_func_kwargs[-1]["middleware"] = str(arg[1])
return_func_type.append(data_spec)
# define the communication message type (dict for single return). NOTE: supports 1 layer depth only
elif isinstance(data_type, dict):
- raise NotImplementedError("Dictionaries are not yet supported as a return type")
+ raise NotImplementedError(
+ "Dictionaries are not yet supported as a return type"
+ )
else:
- raise NotImplementedError(f"Return data type not supported: {data_type}")
+ raise NotImplementedError(
+ f"Return data type not supported: {data_type}"
+ )
func_qualname = func.__qualname__
if func_qualname in cls.__registry:
- cls.__registry[func_qualname]["communicator"].append({
- "return_func_args": return_func_args,
- "return_func_kwargs": return_func_kwargs,
- "return_func_type": return_func_type})
+ cls.__registry[func_qualname]["communicator"].append(
+ {
+ "return_func_args": return_func_args,
+ "return_func_kwargs": return_func_kwargs,
+ "return_func_type": return_func_type,
+ }
+ )
else:
- cls.__registry[func_qualname] = {"communicator": [{
- "return_func_args": return_func_args,
- "return_func_kwargs": return_func_kwargs,
- "return_func_type": return_func_type}]}
+ cls.__registry[func_qualname] = {
+ "communicator": [
+ {
+ "return_func_args": return_func_args,
+ "return_func_kwargs": return_func_kwargs,
+ "return_func_type": return_func_type,
+ }
+ ]
+ }
cls.__registry[func_qualname]["mode"] = None
@wraps(func)
@@ -408,47 +602,101 @@ def wrapper(*wds, **kwds): # triggers on calling the method
instance_address = hex(id(wds[0]))
try:
- instance_id = cls._MiddlewareCommunicator__registry[func.__qualname__]["__WRAPYFI_INSTANCES"].index(instance_address) + 1
+ instance_id = (
+ cls._MiddlewareCommunicator__registry[func.__qualname__][
+ "__WRAPYFI_INSTANCES"
+ ].index(instance_address)
+ + 1
+ )
instance_id = "" if instance_id <= 1 else "." + str(instance_id)
except KeyError:
instance_id = ""
# execute the method as usual
- if cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] is None:
+ if (
+ cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["mode"]
+ is None
+ ):
return func(*wds, **kwds)
kwd = get_default_args(func)
kwd.update(kwds)
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["args"] = wds
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["kwargs"] = kwd
+ cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][
+ "args"
+ ] = wds
+ cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id][
+ "kwargs"
+ ] = kwd
# publishes the method returns
- if cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] == "publish":
+ if (
+ cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["mode"]
+ == "publish"
+ ):
return cls.__trigger_publish(func, instance_id, kwd, *wds, **kwds)
# listens to the publisher and returns the messages
- elif cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] == "listen":
+ elif (
+ cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["mode"]
+ == "listen"
+ ):
return cls.__trigger_listen(func, instance_id, kwd, *wds, **kwds)
# server awaits request from client and replies with method returns
- elif cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] == "reply":
+ elif (
+ cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["mode"]
+ == "reply"
+ ):
return cls.__trigger_reply(func, instance_id, kwd, *wds, **kwds)
# client requests with args from server and awaits reply
- elif cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] == "request":
+ elif (
+ cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["mode"]
+ == "request"
+ ):
return cls.__trigger_request(func, instance_id, kwd, *wds, **kwds)
# WARNING: use with caution. This produces "None" for all the method's returns
- elif cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["mode"] == "disable":
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["last_results"] = []
- for ret_idx in range(len(cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["communicator"])):
- cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["last_results"].append(None)
- return cls._MiddlewareCommunicator__registry[func.__qualname__ + instance_id]["last_results"]
+ elif (
+ cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["mode"]
+ == "disable"
+ ):
+ cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["last_results"] = []
+ for ret_idx in range(
+ len(
+ cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["communicator"]
+ )
+ ):
+ cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["last_results"].append(None)
+ return cls._MiddlewareCommunicator__registry[
+ func.__qualname__ + instance_id
+ ]["last_results"]
return wrapper
+
return encapsulate
- def activate_communication(self, func: Union[str, Callable[..., Any]], mode: Union[str, List[str]]):
+ def activate_communication(
+ self, func: Union[str, Callable[..., Any]], mode: Union[str, List[str]]
+ ):
"""
Activates the communication mode for a registered function in the middleware communicator.
The mode determines how the function will interact with the middleware communicator upon invocation,
@@ -477,16 +725,24 @@ def activate_communication(self, func: Union[str, Callable[..., Any]], mode: Uni
wrapyfi_instances.append(instance_addr)
instance_id = len(wrapyfi_instances)
if instance_id > 1:
- self.__registry[f"{func.__qualname__}.{instance_id}"] = deepcopy(entry, exclude_keys=["wrapped_executor"])
+ self.__registry[f"{func.__qualname__}.{instance_id}"] = (
+ deepcopy(entry, exclude_keys=["wrapped_executor"])
+ )
- instance_qualname = f"{func.__qualname__}.{instance_id}" if instance_id > 1 else func.__qualname__
+ instance_qualname = (
+ f"{func.__qualname__}.{instance_id}"
+ if instance_id > 1
+ else func.__qualname__
+ )
if isinstance(mode, list):
try:
- self.__registry[instance_qualname]["mode"] = mode[instance_id-1]
+ self.__registry[instance_qualname]["mode"] = mode[instance_id - 1]
except IndexError:
- raise IndexError("When mode (publish|listen|disable|null) specified in configuration file is a "
- "list, No. of elements in the list should match the number of instances")
+ raise IndexError(
+ "When mode (publish|listen|disable|null) specified in configuration file is a "
+ "list, No. of elements in the list should match the number of instances"
+ )
else:
self.__registry[instance_qualname]["mode"] = mode
self.__registry[instance_qualname]["instance_addr"] = instance_addr
@@ -538,11 +794,17 @@ def close_instance(cls, instance_addr: Optional[str] = None):
for entry_key, entry_val in cls._MiddlewareCommunicator__registry.items():
if instance_addr in entry_val.get("__WRAPYFI_INSTANCES", []):
if not del_entry:
- del_entry_idx = entry_val["__WRAPYFI_INSTANCES"].index(instance_addr)
+ del_entry_idx = entry_val["__WRAPYFI_INSTANCES"].index(
+ instance_addr
+ )
if del_entry_idx == 0:
del_entry = entry_key
else:
- del_entry = re.sub("\.\d+", "\.", entry_key) + "." + str(del_entry_idx + 1)
+ del_entry = (
+ re.sub("\.\d+", "\.", entry_key)
+ + "."
+ + str(del_entry_idx + 1)
+ )
del_entry_name = re.sub("\.\d+", "\.", entry_key)
else:
if del_entry_name in entry_key:
@@ -551,7 +813,9 @@ def close_instance(cls, instance_addr: Optional[str] = None):
# delete registry entry and all its publishers/listeners/servers/clients
if del_entry:
- for communicator in cls._MiddlewareCommunicator__registry[del_entry]["communicator"]:
+ for communicator in cls._MiddlewareCommunicator__registry[del_entry][
+ "communicator"
+ ]:
wrapped_executor = communicator.get("wrapped_executor", False)
if wrapped_executor:
if isinstance(wrapped_executor, list):
@@ -563,25 +827,29 @@ def close_instance(cls, instance_addr: Optional[str] = None):
if other_entry_keys:
del cls._MiddlewareCommunicator__registry[del_entry]
else:
- cls._MiddlewareCommunicator__registry[del_entry].pop("__WRAPYFI_INSTANCES")
+ cls._MiddlewareCommunicator__registry[del_entry].pop(
+ "__WRAPYFI_INSTANCES"
+ )
cls._MiddlewareCommunicator__registry[del_entry].pop("mode")
- cls._MiddlewareCommunicator__registry[del_entry].pop("instance_addr")
+ cls._MiddlewareCommunicator__registry[del_entry].pop(
+ "instance_addr"
+ )
# shift all entries backwards following the deleted one
if del_entry_idx - 1 < len(other_entry_keys):
for other_entry_key in other_entry_keys[del_entry_idx:]:
- new_key = re.split("\.(\d+)", other_entry_key)
- if len(new_key) == 1:
- new_key = new_key[0]
- elif str(int(new_key[1]) - 1) == "1":
- new_key = new_key[0]
- else:
- new_key = new_key[0] + "." + str(int(new_key[1]) - 1)
- cls._MiddlewareCommunicator__registry[new_key] = \
- cls._MiddlewareCommunicator__registry.pop(other_entry_key)
+ new_key = re.split("\.(\d+)", other_entry_key)
+ if len(new_key) == 1:
+ new_key = new_key[0]
+ elif str(int(new_key[1]) - 1) == "1":
+ new_key = new_key[0]
+ else:
+ new_key = new_key[0] + "." + str(int(new_key[1]) - 1)
+ cls._MiddlewareCommunicator__registry[new_key] = (
+ cls._MiddlewareCommunicator__registry.pop(other_entry_key)
+ )
else:
break
def __del__(self):
self.close()
-
diff --git a/wrapyfi/encoders.py b/wrapyfi/encoders.py
index 9d59f02..d4bd2f6 100644
--- a/wrapyfi/encoders.py
+++ b/wrapyfi/encoders.py
@@ -18,13 +18,14 @@ class JsonEncoder(json.JSONEncoder):
- Numpy ndarray objects
- Objects registered with the PluginRegistrar
"""
+
def __init__(self, **kwargs):
"""
Initialize the JsonEncoder.
:param kwargs: dict: Additional keyword arguments extracting values from the 'serializer_kwargs' key and passing them to the base class. All other keyword arguments are passed to the corresponding Plugin.
"""
- super().__init__(**kwargs.get('serializer_kwargs', {}))
+ super().__init__(**kwargs.get("serializer_kwargs", {}))
self.plugins = dict()
for plugin_key, plugin_val in PluginRegistrar.encoder_registry.items():
self.plugins[plugin_key] = plugin_val(**kwargs)
@@ -37,7 +38,7 @@ def find_plugin(self, obj):
:return: Plugin: The plugin for the given object if its type is registered, None otherwise
"""
for cls in reversed(type(obj).__mro__[:-1]):
- if cls.__module__ == 'collections.abc':
+ if cls.__module__ == "collections.abc":
continue # skip classes from collections.abc
if issubclass(cls, abc.ABCMeta):
if cls.__abstractmethods__:
@@ -52,9 +53,10 @@ def encode(self, obj):
:param obj: Any: The object to encode
:return: str: The JSON string representation of the object returned by the base class
"""
+
def hint_tuples(item):
if isinstance(item, tuple):
- return dict(__wrapyfi__=('tuple', item))
+ return dict(__wrapyfi__=("tuple", item))
if isinstance(item, list):
return [hint_tuples(e) for e in item]
if isinstance(item, dict):
@@ -72,19 +74,19 @@ def default(self, obj):
:return: dict: A dictionary containing the class name and encoded data string
"""
if isinstance(obj, set):
- return dict(__wrapyfi__=('set', list(obj)))
+ return dict(__wrapyfi__=("set", list(obj)))
elif isinstance(obj, datetime):
- return dict(__wrapyfi__=('datetime', obj.isoformat()))
+ return dict(__wrapyfi__=("datetime", obj.isoformat()))
elif isinstance(obj, np.datetime64):
- return dict(__wrapyfi__=('numpy.datetime64', str(obj)))
+ return dict(__wrapyfi__=("numpy.datetime64", str(obj)))
elif isinstance(obj, (np.ndarray, np.generic)):
with io.BytesIO() as memfile:
np.save(memfile, obj)
- obj_data = base64.b64encode(memfile.getvalue()).decode('ascii')
- return dict(__wrapyfi__=('numpy.ndarray', obj_data))
+ obj_data = base64.b64encode(memfile.getvalue()).decode("ascii")
+ return dict(__wrapyfi__=("numpy.ndarray", obj_data))
plugin_match = self.find_plugin(obj)
if plugin_match is not None:
@@ -106,6 +108,7 @@ class JsonDecodeHook(object):
- Numpy ndarray objects
- Objects registered with the PluginRegistrar
"""
+
def __init__(self, **kwargs):
"""
Initialize the JsonDecodeHook.
@@ -124,24 +127,26 @@ def object_hook(self, obj):
:return: Any: The decoded object
"""
if isinstance(obj, dict):
- wrapyfi = obj.get('__wrapyfi__', None)
+ wrapyfi = obj.get("__wrapyfi__", None)
if wrapyfi is not None:
obj_type = wrapyfi[0]
- if obj_type == 'tuple':
+ if obj_type == "tuple":
return tuple(wrapyfi[1])
- elif obj_type == 'set':
+ elif obj_type == "set":
return set(wrapyfi[1])
- elif obj_type == 'datetime':
+ elif obj_type == "datetime":
return datetime.fromisoformat(wrapyfi[1])
- elif obj_type == 'numpy.datetime64':
+ elif obj_type == "numpy.datetime64":
return np.datetime64(wrapyfi[1])
- elif obj_type == 'numpy.ndarray':
- with io.BytesIO(base64.b64decode(wrapyfi[1].encode('ascii'))) as memfile:
+ elif obj_type == "numpy.ndarray":
+ with io.BytesIO(
+ base64.b64decode(wrapyfi[1].encode("ascii"))
+ ) as memfile:
return np.load(memfile)
plugin_match = self.plugins.get(obj_type, None)
diff --git a/wrapyfi/listeners/__init__.py b/wrapyfi/listeners/__init__.py
index 95bd55e..65a8ce4 100755
--- a/wrapyfi/listeners/__init__.py
+++ b/wrapyfi/listeners/__init__.py
@@ -6,11 +6,22 @@
@Listeners.register("MMO", "fallback")
class FallbackListener(Listener):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp",
- should_wait: bool = True, missing_middleware_object: str = "", **kwargs):
- logging.warning(f"Fallback listener employed due to missing middleware or object type: "
- f"{missing_middleware_object}")
- Listener.__init__(self, name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ missing_middleware_object: str = "",
+ **kwargs,
+ ):
+ logging.warning(
+ f"Fallback listener employed due to missing middleware or object type: "
+ f"{missing_middleware_object}"
+ )
+ Listener.__init__(
+ self, name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
self.missing_middleware_object = missing_middleware_object
def establish(self, repeats: int = -1, **kwargs):
@@ -20,4 +31,4 @@ def listen(self):
return None
def close(self):
- return None
\ No newline at end of file
+ return None
diff --git a/wrapyfi/listeners/ros.py b/wrapyfi/listeners/ros.py
index 456cdf3..fbd1772 100755
--- a/wrapyfi/listeners/ros.py
+++ b/wrapyfi/listeners/ros.py
@@ -24,8 +24,16 @@
class ROSListener(Listener):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True,
- queue_size: int = QUEUE_SIZE, ros_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ ros_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the subscriber.
@@ -38,11 +46,15 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
:param kwargs: dict: Additional kwargs for the subscriber
"""
if carrier or carrier != "tcp":
- logging.warning("[ROS] ROS does not support other carriers than TCP for PUB/SUB pattern. Using TCP.")
+ logging.warning(
+ "[ROS] ROS does not support other carriers than TCP for PUB/SUB pattern. Using TCP."
+ )
carrier = "tcp"
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ super().__init__(
+ name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
ROSMiddleware.activate(**ros_kwargs or {})
-
+
self.queue_size = queue_size
def close(self):
@@ -60,8 +72,16 @@ def __del__(self):
@Listeners.register("NativeObject", "ros")
class ROSNativeObjectListener(ROSListener):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int =QUEUE_SIZE,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObject listener using the ROS String message assuming the data is serialized as a JSON string.
Deserializes the data (including plugins) using the decoder and parses it to a Python object.
@@ -73,7 +93,14 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
:param queue_size: int: Size of the queue for the subscriber. Default is 5
:param deserializer_kwargs: dict: Additional kwargs for the deserializer
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ queue_size=queue_size,
+ **kwargs,
+ )
self._subscriber = self._queue = None
@@ -86,8 +113,16 @@ def establish(self):
"""
Establish the subscriber.
"""
- self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size)
- self._subscriber = rospy.Subscriber(self.in_topic, std_msgs.msg.String, callback=self._message_callback)
+ self._queue = queue.Queue(
+ maxsize=(
+ 0
+ if self.queue_size is None or self.queue_size <= 0
+ else self.queue_size
+ )
+ )
+ self._subscriber = rospy.Subscriber(
+ self.in_topic, std_msgs.msg.String, callback=self._message_callback
+ )
self.established = True
def listen(self) -> Any:
@@ -100,7 +135,11 @@ def listen(self) -> Any:
self.establish()
try:
obj_str = self._queue.get(block=self.should_wait)
- return json.loads(obj_str, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ return json.loads(
+ obj_str,
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
except queue.Empty:
return None
@@ -113,14 +152,28 @@ def _message_callback(self, msg):
try:
self._queue.put(msg.data, block=False)
except queue.Full:
- logging.warning(f"[ROS] Discarding data because listener queue is full: {self.in_topic}")
+ logging.warning(
+ f"[ROS] Discarding data because listener queue is full: {self.in_topic}"
+ )
@Listeners.register("Image", "ros")
class ROSImageListener(ROSListener):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE,
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ **kwargs,
+ ):
"""
The Image listener using the ROS Image message parsed to a numpy array.
@@ -135,7 +188,14 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
:param fp: bool: True if the image is floating point, False if it is integer. Default is False
:param jpg: bool: True if the image should be decompressed from JPG. Default is False
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ queue_size=queue_size,
+ **kwargs,
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -143,13 +203,13 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
self.jpg = jpg
if self.fp:
- self._encoding = '32FC3' if self.rgb else '32FC1'
+ self._encoding = "32FC3" if self.rgb else "32FC1"
self._type = np.float32
else:
- self._encoding = 'bgr8' if self.rgb else 'mono8'
+ self._encoding = "bgr8" if self.rgb else "mono8"
self._type = np.uint8
if self.jpg:
- self._encoding = 'jpeg'
+ self._encoding = "jpeg"
self._type = np.uint8
self._pixel_bytes = (3 if self.rgb else 1) * np.dtype(self._type).itemsize
@@ -162,11 +222,23 @@ def establish(self):
"""
Establish the subscriber.
"""
- self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size)
+ self._queue = queue.Queue(
+ maxsize=(
+ 0
+ if self.queue_size is None or self.queue_size <= 0
+ else self.queue_size
+ )
+ )
if self.jpg:
- self._subscriber = rospy.Subscriber(self.in_topic, sensor_msgs.msg.CompressedImage, callback=self._message_callback)
+ self._subscriber = rospy.Subscriber(
+ self.in_topic,
+ sensor_msgs.msg.CompressedImage,
+ callback=self._message_callback,
+ )
else:
- self._subscriber = rospy.Subscriber(self.in_topic, sensor_msgs.msg.Image, callback=self._message_callback)
+ self._subscriber = rospy.Subscriber(
+ self.in_topic, sensor_msgs.msg.Image, callback=self._message_callback
+ )
self.established = True
def listen(self):
@@ -180,19 +252,32 @@ def listen(self):
try:
if self.jpg:
format, data = self._queue.get(block=self.should_wait)
- if format != 'jpeg':
+ if format != "jpeg":
raise ValueError(f"Unsupported image format: {format}")
if self.rgb:
img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
else:
- img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE)
+ img = cv2.imdecode(
+ np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE
+ )
else:
- height, width, encoding, is_bigendian, data = self._queue.get(block=self.should_wait)
+ height, width, encoding, is_bigendian, data = self._queue.get(
+ block=self.should_wait
+ )
if encoding != self._encoding:
raise ValueError("Incorrect encoding for listener")
- if 0 < self.width != width or 0 < self.height != height or len(data) != height * width * self._pixel_bytes:
+ if (
+ 0 < self.width != width
+ or 0 < self.height != height
+ or len(data) != height * width * self._pixel_bytes
+ ):
raise ValueError("Incorrect image shape for listener")
- img = np.frombuffer(data, dtype=np.dtype(self._type).newbyteorder('>' if is_bigendian else '<')).reshape((height, width, -1))
+ img = np.frombuffer(
+ data,
+ dtype=np.dtype(self._type).newbyteorder(
+ ">" if is_bigendian else "<"
+ ),
+ ).reshape((height, width, -1))
if img.shape[2] == 1:
img = img.squeeze(axis=2)
return img
@@ -209,16 +294,37 @@ def _message_callback(self, data):
if self.jpg:
self._queue.put((data.format, data.data), block=False)
else:
- self._queue.put((data.height, data.width, data.encoding, data.is_bigendian, data.data), block=False)
+ self._queue.put(
+ (
+ data.height,
+ data.width,
+ data.encoding,
+ data.is_bigendian,
+ data.data,
+ ),
+ block=False,
+ )
except queue.Full:
- logging.warning(f"[ROS] Discarding data because listener queue is full: {self.in_topic}")
+ logging.warning(
+ f"[ROS] Discarding data because listener queue is full: {self.in_topic}"
+ )
@Listeners.register("AudioChunk", "ros")
class ROSAudioChunkListener(ROSListener):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE,
- channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ **kwargs,
+ ):
"""
The AudioChunk listener using the ROS Image message parsed to a numpy array.
@@ -231,8 +337,19 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
:param rate: int: Sampling rate of the audio. Default is 44100
:param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio)
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size,
- width=chunk, height=channels, rgb=False, fp=True, jpg=False, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ queue_size=queue_size,
+ width=chunk,
+ height=channels,
+ rgb=False,
+ fp=True,
+ jpg=False,
+ **kwargs,
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -248,13 +365,25 @@ def establish(self):
from wrapyfi_ros_interfaces.msg import ROSAudioMessage
except ImportError:
import wrapyfi
- logging.error("[ROS] Could not import ROSAudioMessage. "
- "Make sure the ROS messages in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS] Could not import ROSAudioMessage. "
+ "Make sure the ROS messages in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros_interfaces_lnk.html"
+ )
sys.exit(1)
- self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size)
- self._subscriber = rospy.Subscriber(self.in_topic, ROSAudioMessage, callback=self._message_callback)
+ self._queue = queue.Queue(
+ maxsize=(
+ 0
+ if self.queue_size is None or self.queue_size <= 0
+ else self.queue_size
+ )
+ )
+ self._subscriber = rospy.Subscriber(
+ self.in_topic, ROSAudioMessage, callback=self._message_callback
+ )
self.established = True
def listen(self):
@@ -266,14 +395,23 @@ def listen(self):
if not self.established:
self.establish()
try:
- chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(block=self.should_wait)
+ chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(
+ block=self.should_wait
+ )
if 0 < self.rate != rate:
raise ValueError("Incorrect audio rate for listener")
- if encoding not in ['S16LE', 'S16BE']:
+ if encoding not in ["S16LE", "S16BE"]:
raise ValueError("Incorrect encoding for listener")
- if 0 < self.chunk != chunk or self.channels != channels or len(data) != chunk * channels * 4:
+ if (
+ 0 < self.chunk != chunk
+ or self.channels != channels
+ or len(data) != chunk * channels * 4
+ ):
raise ValueError("Incorrect audio shape for listener")
- aud = np.frombuffer(data, dtype=np.dtype(np.float32).newbyteorder('>' if is_bigendian else '<')).reshape((chunk, channels))
+ aud = np.frombuffer(
+ data,
+ dtype=np.dtype(np.float32).newbyteorder(">" if is_bigendian else "<"),
+ ).reshape((chunk, channels))
# aud = aud / 32767.0
return aud, rate
except queue.Empty:
@@ -286,9 +424,21 @@ def _message_callback(self, data):
:param data: wrapyfi_ros_interfaces.msg.ROSAudioMessage: The received message
"""
try:
- self._queue.put((data.chunk_size, data.channels, data.sample_rate, data.encoding, data.is_bigendian, data.data), block=False)
+ self._queue.put(
+ (
+ data.chunk_size,
+ data.channels,
+ data.sample_rate,
+ data.encoding,
+ data.is_bigendian,
+ data.data,
+ ),
+ block=False,
+ )
except queue.Full:
- logging.warning(f"[ROS] Discarding data because listener queue is full: {self.in_topic}")
+ logging.warning(
+ f"[ROS] Discarding data because listener queue is full: {self.in_topic}"
+ )
@Listeners.register("Properties", "ros")
@@ -300,7 +450,16 @@ class ROSPropertiesListener(ROSListener):
but care should be taken when using dictionaries, since they are analogous with node namespaces:
http://wiki.ros.org/rospy/Overview/Parameter%20Server
"""
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE, **kwargs):
+
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ **kwargs,
+ ):
"""
The PropertiesListener using the ROS Parameter Server.
@@ -310,7 +469,14 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
:param should_wait: bool: Whether the subscriber should wait for a parameter to be set. Default is True
:param queue_size: int: Size of the queue for the subscriber. Default is 5
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ queue_size=queue_size,
+ **kwargs,
+ )
self._subscriber = self._queue = None
if not self.should_wait:
@@ -318,7 +484,9 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
self.previous_property = False
- def await_connection(self, in_topic: Optional[int] = None, repeats: Optional[int] = None):
+ def await_connection(
+ self, in_topic: Optional[int] = None, repeats: Optional[int] = None
+ ):
"""
Wait for a parameter to be set.
@@ -375,7 +543,15 @@ def listen(self):
@Listeners.register("ROSMessage", "ros")
class ROSMessageListener(ROSListener):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ **kwargs,
+ ):
"""
The ROSMessageListener using the ROS message type inferred from the message type. Supports standard ROS msgs.
@@ -385,7 +561,14 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
:param should_wait: bool: Whether the subscriber should wait for a message to be published. Default is True
:param queue_size: int: Size of the queue for the subscriber. Default is 5
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ queue_size=queue_size,
+ **kwargs,
+ )
self._subscriber = self._queue = self._topic_type = None
ListenerWatchDog().add_listener(self)
@@ -394,11 +577,21 @@ def establish(self):
"""
Establish the subscriber.
"""
- self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size)
- self._topic_type, topic_str, _ = rostopic.get_topic_class(self.in_topic, blocking=self.should_wait)
+ self._queue = queue.Queue(
+ maxsize=(
+ 0
+ if self.queue_size is None or self.queue_size <= 0
+ else self.queue_size
+ )
+ )
+ self._topic_type, topic_str, _ = rostopic.get_topic_class(
+ self.in_topic, blocking=self.should_wait
+ )
if self._topic_type is None:
return
- self._subscriber = rospy.Subscriber(self.in_topic, self._topic_type, callback=self._message_callback)
+ self._subscriber = rospy.Subscriber(
+ self.in_topic, self._topic_type, callback=self._message_callback
+ )
self.established = True
def listen(self):
@@ -425,5 +618,6 @@ def _message_callback(self, msg):
try:
self._queue.put(msg, block=False)
except queue.Full:
- logging.warning(f"[ROS] Discarding data because listener queue is full: {self.in_topic}")
-
+ logging.warning(
+ f"[ROS] Discarding data because listener queue is full: {self.in_topic}"
+ )
diff --git a/wrapyfi/listeners/ros2.py b/wrapyfi/listeners/ros2.py
index 984528b..d498dae 100755
--- a/wrapyfi/listeners/ros2.py
+++ b/wrapyfi/listeners/ros2.py
@@ -27,8 +27,15 @@
class ROS2Listener(Listener, Node):
- def __init__(self, name: str, in_topic: str, should_wait: bool = True,
- queue_size: int = QUEUE_SIZE, ros2_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ ros2_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the subscriber.
@@ -41,12 +48,16 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True,
"""
carrier = "tcp"
if "carrier" in kwargs and kwargs["carrier"] not in ["", None]:
- logging.warning("[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP.")
+ logging.warning(
+ "[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP."
+ )
if "carrier" in kwargs:
del kwargs["carrier"]
ROS2Middleware.activate(**ros2_kwargs or {})
- Listener.__init__(self, name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ Listener.__init__(
+ self, name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
Node.__init__(self, name + str(hex(id(self))), allow_undeclared_parameters=True)
self.queue_size = queue_size
@@ -66,8 +77,15 @@ def __del__(self):
@Listeners.register("NativeObject", "ros2")
class ROS2NativeObjectListener(ROS2Listener):
- def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObject listener using the ROS 2 String message assuming the data is serialized as a JSON string.
Deserializes the data (including plugins) using the decoder and parses it to a native object.
@@ -78,7 +96,9 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_siz
:param queue_size: int: Size of the queue for the subscriber. Default is 5
:param deserializer_kwargs: dict: Additional kwargs for the deserializer
"""
- super().__init__(name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs
+ )
self._subscriber = self._queue = None
@@ -91,8 +111,19 @@ def establish(self):
"""
Establish the subscriber.
"""
- self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size)
- self._subscriber = self.create_subscription(std_msgs.msg.String, self.in_topic, callback=self._message_callback, qos_profile=self.queue_size)
+ self._queue = queue.Queue(
+ maxsize=(
+ 0
+ if self.queue_size is None or self.queue_size <= 0
+ else self.queue_size
+ )
+ )
+ self._subscriber = self.create_subscription(
+ std_msgs.msg.String,
+ self.in_topic,
+ callback=self._message_callback,
+ qos_profile=self.queue_size,
+ )
self.established = True
def listen(self):
@@ -106,7 +137,11 @@ def listen(self):
try:
rclpy.spin_once(self, timeout_sec=WAIT[self.should_wait])
obj_str = self._queue.get(block=self.should_wait)
- return json.loads(obj_str, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ return json.loads(
+ obj_str,
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
except queue.Empty:
return None
@@ -119,14 +154,27 @@ def _message_callback(self, msg):
try:
self._queue.put(msg.data, block=False)
except queue.Full:
- logging.warning(f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}")
+ logging.warning(
+ f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}"
+ )
@Listeners.register("Image", "ros2")
class ROS2ImageListener(ROS2Listener):
- def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE,
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ **kwargs,
+ ):
"""
The Image listener using the ROS 2 Image message parsed to a numpy array.
@@ -140,7 +188,9 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_siz
:param fp: bool: True if the image is floating point, False if it is integer. Default is False
:param jpg: bool: True if the image should be decompressed from JPG. Default is False
"""
- super().__init__(name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -148,13 +198,13 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_siz
self.jpg = jpg
if self.fp:
- self._encoding = '32FC3' if self.rgb else '32FC1'
+ self._encoding = "32FC3" if self.rgb else "32FC1"
self._type = np.float32
else:
- self._encoding = 'bgr8' if self.rgb else 'mono8'
+ self._encoding = "bgr8" if self.rgb else "mono8"
self._type = np.uint8
if self.jpg:
- self._encoding = 'jpeg'
+ self._encoding = "jpeg"
self._type = np.uint8
self._pixel_bytes = (3 if self.rgb else 1) * np.dtype(self._type).itemsize
@@ -167,11 +217,27 @@ def establish(self):
"""
Establish the subscriber
"""
- self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size)
+ self._queue = queue.Queue(
+ maxsize=(
+ 0
+ if self.queue_size is None or self.queue_size <= 0
+ else self.queue_size
+ )
+ )
if self.jpg:
- self._subscriber = self.create_subscription(sensor_msgs.msg.CompressedImage, self.in_topic, callback=self._message_callback, qos_profile=self.queue_size)
+ self._subscriber = self.create_subscription(
+ sensor_msgs.msg.CompressedImage,
+ self.in_topic,
+ callback=self._message_callback,
+ qos_profile=self.queue_size,
+ )
else:
- self._subscriber = self.create_subscription(sensor_msgs.msg.Image, self.in_topic, callback=self._message_callback, qos_profile=self.queue_size)
+ self._subscriber = self.create_subscription(
+ sensor_msgs.msg.Image,
+ self.in_topic,
+ callback=self._message_callback,
+ qos_profile=self.queue_size,
+ )
self.established = True
def listen(self):
@@ -186,19 +252,32 @@ def listen(self):
rclpy.spin_once(self, timeout_sec=WAIT[self.should_wait])
if self.jpg:
format, data = self._queue.get(block=self.should_wait)
- if format != 'jpeg':
+ if format != "jpeg":
raise ValueError(f"Unsupported image format: {format}")
if self.rgb:
img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
else:
- img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE)
+ img = cv2.imdecode(
+ np.frombuffer(data, np.uint8), cv2.IMREAD_GRAYSCALE
+ )
else:
- height, width, encoding, is_bigendian, data = self._queue.get(block=self.should_wait)
+ height, width, encoding, is_bigendian, data = self._queue.get(
+ block=self.should_wait
+ )
if encoding != self._encoding:
raise ValueError("Incorrect encoding for listener")
- if 0 < self.width != width or 0 < self.height != height or len(data) != height * width * self._pixel_bytes:
+ if (
+ 0 < self.width != width
+ or 0 < self.height != height
+ or len(data) != height * width * self._pixel_bytes
+ ):
raise ValueError("Incorrect image shape for listener")
- img = np.frombuffer(data, dtype=np.dtype(self._type).newbyteorder('>' if is_bigendian else '<')).reshape((height, width, -1))
+ img = np.frombuffer(
+ data,
+ dtype=np.dtype(self._type).newbyteorder(
+ ">" if is_bigendian else "<"
+ ),
+ ).reshape((height, width, -1))
if img.shape[2] == 1:
img = img.squeeze(axis=2)
return img
@@ -215,16 +294,30 @@ def _message_callback(self, msg):
if self.jpg:
self._queue.put((msg.format, msg.data), block=False)
else:
- self._queue.put((msg.height, msg.width, msg.encoding, msg.is_bigendian, msg.data), block=False)
+ self._queue.put(
+ (msg.height, msg.width, msg.encoding, msg.is_bigendian, msg.data),
+ block=False,
+ )
except queue.Full:
- logging.warning(f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}")
+ logging.warning(
+ f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}"
+ )
@Listeners.register("AudioChunk", "ros2")
class ROS2AudioChunkListener(ROS2Listener):
- def __init__(self, name: str, in_topic: str, should_wait: bool = True,
- queue_size: int = QUEUE_SIZE, channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ **kwargs,
+ ):
"""
The AudioChunk listener using the ROS 2 Audio message parsed to a numpy array.
@@ -236,7 +329,9 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True,
:param rate: int: Sampling rate of the audio. Default is 44100
:param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio)
"""
- super().__init__(name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -253,13 +348,28 @@ def establish(self):
from wrapyfi_ros2_interfaces.msg import ROS2AudioMessage
except ImportError:
import wrapyfi
- logging.error("[ROS 2] Could not import ROS2AudioMessage. "
- "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros2_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS 2] Could not import ROS2AudioMessage. "
+ "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros2_interfaces_lnk.html"
+ )
sys.exit(1)
- self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size)
- self._subscriber = self.create_subscription(ROS2AudioMessage, self.in_topic, callback=self._message_callback, qos_profile=self.queue_size)
+ self._queue = queue.Queue(
+ maxsize=(
+ 0
+ if self.queue_size is None or self.queue_size <= 0
+ else self.queue_size
+ )
+ )
+ self._subscriber = self.create_subscription(
+ ROS2AudioMessage,
+ self.in_topic,
+ callback=self._message_callback,
+ qos_profile=self.queue_size,
+ )
self.established = True
def listen(self):
@@ -272,14 +382,23 @@ def listen(self):
self.establish()
try:
rclpy.spin_once(self, timeout_sec=WAIT[self.should_wait])
- chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(block=self.should_wait)
+ chunk, channels, rate, encoding, is_bigendian, data = self._queue.get(
+ block=self.should_wait
+ )
if 0 < self.rate != rate:
raise ValueError("Incorrect audio rate for publisher")
- if encoding not in ['S16LE', 'S16BE']:
+ if encoding not in ["S16LE", "S16BE"]:
raise ValueError("Incorrect encoding for listener")
- if 0 < self.chunk != chunk or self.channels != channels or len(data) != chunk * channels * 4:
+ if (
+ 0 < self.chunk != chunk
+ or self.channels != channels
+ or len(data) != chunk * channels * 4
+ ):
raise ValueError("Incorrect audio shape for listener")
- aud = np.frombuffer(data, dtype=np.dtype(np.float32).newbyteorder('>' if is_bigendian else '<')).reshape((chunk, channels))
+ aud = np.frombuffer(
+ data,
+ dtype=np.dtype(np.float32).newbyteorder(">" if is_bigendian else "<"),
+ ).reshape((chunk, channels))
# aud = aud / 32767.0
return aud, rate
except queue.Empty:
@@ -292,9 +411,21 @@ def _message_callback(self, msg):
:param msg: wrapyfi_ros2_interfaces.msg.ROS2AudioMessage: The received message
"""
try:
- self._queue.put((msg.chunk_size, msg.channels, msg.sample_rate, msg.encoding, msg.is_bigendian, msg.data), block=False)
+ self._queue.put(
+ (
+ msg.chunk_size,
+ msg.channels,
+ msg.sample_rate,
+ msg.encoding,
+ msg.is_bigendian,
+ msg.data,
+ ),
+ block=False,
+ )
except queue.Full:
- logging.warning(f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}")
+ logging.warning(
+ f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}"
+ )
@Listeners.register("Properties", "ros2")
@@ -307,7 +438,14 @@ def __init__(self, name, in_topic, **kwargs):
@Listeners.register("ROS2Message", "ros2")
class ROS2MessageListener(ROS2Listener):
- def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ **kwargs,
+ ):
"""
The ROS2MessageListener using the ROS 2 message type inferred from the message type. Supports standard ROS 2 msgs.
@@ -316,8 +454,16 @@ def __init__(self, name: str, in_topic: str, should_wait: bool = True, queue_siz
:param should_wait: bool: Whether the subscriber should wait for the publisher to transmit a message. Default is True
:param queue_size: int: Size of the queue for the subscriber. Default is 5
"""
- super().__init__(name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs)
- self._queue = queue.Queue(maxsize=0 if self.queue_size is None or self.queue_size <= 0 else self.queue_size)
+ super().__init__(
+ name, in_topic, should_wait=should_wait, queue_size=queue_size, **kwargs
+ )
+ self._queue = queue.Queue(
+ maxsize=(
+ 0
+ if self.queue_size is None or self.queue_size <= 0
+ else self.queue_size
+ )
+ )
def get_topic_type(self, topic_name):
"""
@@ -345,11 +491,16 @@ def establish(self):
if not topic_type_str:
return None
- module_name, class_name = topic_type_str.rsplit('/', 1)
- module_name = module_name.replace('/', '.')
+ module_name, class_name = topic_type_str.rsplit("/", 1)
+ module_name = module_name.replace("/", ".")
MessageType = getattr(importlib.import_module(module_name), class_name)
- self._subscriber = self.create_subscription(MessageType, self.in_topic, callback=self._message_callback, qos_profile=self.queue_size)
+ self._subscriber = self.create_subscription(
+ MessageType,
+ self.in_topic,
+ callback=self._message_callback,
+ qos_profile=self.queue_size,
+ )
self.established = True
def listen(self):
@@ -377,4 +528,6 @@ def _message_callback(self, msg):
try:
self._queue.put(msg, block=False)
except queue.Full:
- logging.warning(f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}")
+ logging.warning(
+ f"[ROS 2] Discarding data because listener queue is full: {self.in_topic}"
+ )
diff --git a/wrapyfi/listeners/yarp.py b/wrapyfi/listeners/yarp.py
index ec2f812..470252f 100755
--- a/wrapyfi/listeners/yarp.py
+++ b/wrapyfi/listeners/yarp.py
@@ -19,8 +19,16 @@
class YarpListener(Listener):
- def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True,
- persistent: bool = True, yarp_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ should_wait: bool = True,
+ persistent: bool = True,
+ yarp_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the subscriber.
@@ -32,14 +40,18 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca
:param yarp_kwargs: dict: Additional kwargs for the Yarp middleware
:param kwargs: dict: Additional kwargs for the subscriber
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ super().__init__(
+ name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
self.style = yarp.ContactStyle()
self.style.persistent = persistent
self.style.carrier = self.carrier
YarpMiddleware.activate(**yarp_kwargs or {})
- def await_connection(self, in_topic: Optional[str] = None, repeats: Optional[int] = None):
+ def await_connection(
+ self, in_topic: Optional[str] = None, repeats: Optional[int] = None
+ ):
"""
Wait for the publisher to connect to the subscriber.
@@ -95,8 +107,16 @@ def __del__(self):
@Listeners.register("NativeObject", "yarp")
class YarpNativeObjectListener(YarpListener):
- def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True,
- persistent: bool = True, deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ should_wait: bool = True,
+ persistent: bool = True,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObject listener using the BufferedPortBottle string construct assuming the data is serialized as a JSON string.
Deserializes the data (including plugins) using the decoder and parses it to a Python object.
@@ -108,7 +128,14 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca
:param persistent: bool: Whether the subscriber port should remain connected after closure. Default is True
:param deserializer_kwargs: dict: Additional kwargs for the deserializer
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, persistent=persistent, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ persistent=persistent,
+ **kwargs,
+ )
self._port = self._netconnect = None
self._plugin_decoder_hook = JsonDecodeHook(**kwargs).object_hook
@@ -130,9 +157,13 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
rnd_id = str(np.random.randint(100000, size=1)[0])
self._port.open(self.in_topic + ":in" + rnd_id)
if self.style.persistent:
- self._netconnect = yarp.Network.connect(self.in_topic, self.in_topic + ":in" + rnd_id, self.style)
+ self._netconnect = yarp.Network.connect(
+ self.in_topic, self.in_topic + ":in" + rnd_id, self.style
+ )
else:
- self._netconnect = yarp.Network.connect(self.in_topic, self.in_topic + ":in" + rnd_id, self.carrier)
+ self._netconnect = yarp.Network.connect(
+ self.in_topic, self.in_topic + ":in" + rnd_id, self.carrier
+ )
return self.check_establishment(established)
def listen(self):
@@ -147,7 +178,11 @@ def listen(self):
return None
obj_port = self.read_port(self._port)
if obj_port is not None:
- return json.loads(obj_port.get(0).asString(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ return json.loads(
+ obj_port.get(0).asString(),
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
else:
return None
@@ -155,8 +190,20 @@ def listen(self):
@Listeners.register("Image", "yarp")
class YarpImageListener(YarpListener):
- def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True,
- persistent: bool = True, width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ should_wait: bool = True,
+ persistent: bool = True,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ **kwargs,
+ ):
"""
The Image listener using the BufferedPortImage construct parsed to a numpy array.
@@ -171,7 +218,14 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca
:param fp: bool: True if the image is floating point, False if it is integer. Default is False
:param jpg: bool: True if the image should be decompressed from JPG. Default is False
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, persistent=persistent, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ persistent=persistent,
+ **kwargs,
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -195,16 +249,30 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
if self.jpg:
self._port = yarp.BufferedPortBottle()
elif self.rgb:
- self._port = yarp.BufferedPortImageRgbFloat() if self.fp else yarp.BufferedPortImageRgb()
+ self._port = (
+ yarp.BufferedPortImageRgbFloat()
+ if self.fp
+ else yarp.BufferedPortImageRgb()
+ )
else:
- self._port = yarp.BufferedPortImageFloat() if self.fp else yarp.BufferedPortImageMono()
+ self._port = (
+ yarp.BufferedPortImageFloat()
+ if self.fp
+ else yarp.BufferedPortImageMono()
+ )
self._type = np.float32 if self.fp else np.uint8
- in_topic_connect = f"{self.in_topic}:in{np.random.randint(100000, size=1).item()}"
+ in_topic_connect = (
+ f"{self.in_topic}:in{np.random.randint(100000, size=1).item()}"
+ )
self._port.open(in_topic_connect)
if self.style.persistent:
- self._netconnect = yarp.Network.connect(self.in_topic, in_topic_connect, self.style)
+ self._netconnect = yarp.Network.connect(
+ self.in_topic, in_topic_connect, self.style
+ )
else:
- self._netconnect = yarp.Network.connect(self.in_topic, in_topic_connect, self.carrier)
+ self._netconnect = yarp.Network.connect(
+ self.in_topic, in_topic_connect, self.carrier
+ )
return self.check_establishment(established)
def listen(self):
@@ -222,21 +290,34 @@ def listen(self):
return None
if self.jpg:
img_str = ret_img_msg.get(0).asString()
- with io.BytesIO(base64.b64decode(img_str.encode('ascii'))) as memfile:
+ with io.BytesIO(base64.b64decode(img_str.encode("ascii"))) as memfile:
img_str = np.load(memfile)
if self.rgb:
img = cv2.imdecode(np.frombuffer(img_str, np.uint8), cv2.IMREAD_COLOR)
else:
- img = cv2.imdecode(np.frombuffer(img_str, np.uint8), cv2.IMREAD_GRAYSCALE)
+ img = cv2.imdecode(
+ np.frombuffer(img_str, np.uint8), cv2.IMREAD_GRAYSCALE
+ )
return img
else:
- if 0 < self.width != ret_img_msg.width() or 0 < self.height != ret_img_msg.height():
+ if (
+ 0 < self.width != ret_img_msg.width()
+ or 0 < self.height != ret_img_msg.height()
+ ):
raise ValueError("Incorrect image shape for listener")
elif self.rgb:
- img = np.zeros((ret_img_msg.height(), ret_img_msg.width(), 3), dtype=self._type, order='C')
+ img = np.zeros(
+ (ret_img_msg.height(), ret_img_msg.width(), 3),
+ dtype=self._type,
+ order="C",
+ )
img_port = yarp.ImageRgbFloat() if self.fp else yarp.ImageRgb()
else:
- img = np.zeros((ret_img_msg.height(), ret_img_msg.width()), dtype=self._type, order='C')
+ img = np.zeros(
+ (ret_img_msg.height(), ret_img_msg.width()),
+ dtype=self._type,
+ order="C",
+ )
img_port = yarp.ImageFloat() if self.fp else yarp.ImageMono()
img_port.resize(img.shape[1], img.shape[0])
img_port.setExternal(img.data, img.shape[1], img.shape[0])
@@ -247,8 +328,18 @@ def listen(self):
@Listeners.register("AudioChunk", "yarp")
class YarpAudioChunkListener(YarpListener):
- def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True,
- persistent: bool = True, channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ should_wait: bool = True,
+ persistent: bool = True,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ **kwargs,
+ ):
"""
The AudioChunk listener using the Sound construct parsed as a numpy array.
@@ -261,7 +352,14 @@ def __init__(self, name: str, in_topic: str, carrier: Literal["tcp", "udp", "mca
:param rate: int: Sampling rate of the audio. Default is 44100
:param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio)
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, persistent=persistent, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ persistent=persistent,
+ **kwargs,
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -283,7 +381,9 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
rnd_id = str(np.random.randint(100000, size=1)[0])
self._port = yarp.Port()
self._port.open(self.in_topic + ":in" + rnd_id)
- self._netconnect = yarp.Network.connect(self.in_topic, self.in_topic + ":in" + rnd_id, self.carrier)
+ self._netconnect = yarp.Network.connect(
+ self.in_topic, self.in_topic + ":in" + rnd_id, self.carrier
+ )
self._sound_msg = yarp.Sound()
self._port.read(self._sound_msg)
@@ -307,7 +407,10 @@ def listen(self):
if not established:
return None
self._port.read(self._sound_msg)
- aud = np.array([self._sound_msg.get(i) for i in range(self._sound_msg.getSamples())], dtype=np.int16)
+ aud = np.array(
+ [self._sound_msg.get(i) for i in range(self._sound_msg.getSamples())],
+ dtype=np.int16,
+ )
aud = aud.astype(np.float32) / 32767.0
return aud, self.rate
diff --git a/wrapyfi/listeners/zeromq.py b/wrapyfi/listeners/zeromq.py
index b668ddf..961ee8e 100755
--- a/wrapyfi/listeners/zeromq.py
+++ b/wrapyfi/listeners/zeromq.py
@@ -15,18 +15,32 @@
SOCKET_IP = os.environ.get("WRAPYFI_ZEROMQ_SOCKET_IP", "127.0.0.1")
SOCKET_PUB_PORT = int(os.environ.get("WRAPYFI_ZEROMQ_SOCKET_PUB_PORT", 5555))
-ZEROMQ_PUBSUB_MONITOR_TOPIC = os.environ.get("WRAPYFI_ZEROMQ_PUBSUB_MONITOR_TOPIC", "ZEROMQ/CONNECTIONS")
-ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN = os.environ.get("WRAPYFI_ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN", "process")
+ZEROMQ_PUBSUB_MONITOR_TOPIC = os.environ.get(
+ "WRAPYFI_ZEROMQ_PUBSUB_MONITOR_TOPIC", "ZEROMQ/CONNECTIONS"
+)
+ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN = os.environ.get(
+ "WRAPYFI_ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN", "process"
+)
WATCHDOG_POLL_REPEAT = None
class ZeroMQListener(Listener):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True,
- socket_ip: str = SOCKET_IP, socket_pub_port: int = SOCKET_PUB_PORT,
- pubsub_monitor_topic: str = ZEROMQ_PUBSUB_MONITOR_TOPIC,
- pubsub_monitor_listener_spawn: Optional[str] = ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN,
- zeromq_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ socket_ip: str = SOCKET_IP,
+ socket_pub_port: int = SOCKET_PUB_PORT,
+ pubsub_monitor_topic: str = ZEROMQ_PUBSUB_MONITOR_TOPIC,
+ pubsub_monitor_listener_spawn: Optional[
+ str
+ ] = ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN,
+ zeromq_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the subscriber.
@@ -44,20 +58,28 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
:param kwargs: dict: Additional kwargs for the subscriber
"""
if carrier or carrier != "tcp":
- logging.warning("[ZeroMQ] ZeroMQ does not support other carriers than TCP for PUB/SUB pattern. Using TCP.")
+ logging.warning(
+ "[ZeroMQ] ZeroMQ does not support other carriers than TCP for PUB/SUB pattern. Using TCP."
+ )
carrier = "tcp"
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ super().__init__(
+ name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
self.socket_address = f"{carrier}://{socket_ip}:{socket_pub_port}"
- ZeroMQMiddlewarePubSub.activate(socket_pub_address=self.socket_address,
- pubsub_monitor_topic=pubsub_monitor_topic,
- pubsub_monitor_listener_spawn=pubsub_monitor_listener_spawn,
- **zeromq_kwargs or {})
+ ZeroMQMiddlewarePubSub.activate(
+ socket_pub_address=self.socket_address,
+ pubsub_monitor_topic=pubsub_monitor_topic,
+ pubsub_monitor_listener_spawn=pubsub_monitor_listener_spawn,
+ **zeromq_kwargs or {},
+ )
ZeroMQMiddlewarePubSub().shared_monitor_data.add_topic(self.in_topic)
- def await_connection(self, socket=None, in_topic: Optional[str] = None, repeats: Optional[int] = None):
+ def await_connection(
+ self, socket=None, in_topic: Optional[str] = None, repeats: Optional[int] = None
+ ):
"""
Wait for the connection to be established.
@@ -78,7 +100,9 @@ def await_connection(self, socket=None, in_topic: Optional[str] = None, repeats:
while repeats > 0 or repeats <= -1:
repeats -= 1
- connected = ZeroMQMiddlewarePubSub().shared_monitor_data.is_connected(in_topic)
+ connected = ZeroMQMiddlewarePubSub().shared_monitor_data.is_connected(
+ in_topic
+ )
if connected:
logging.info(f"[ZeroMQ] Connected to input port: {in_topic}")
break
@@ -118,8 +142,15 @@ def __del__(self):
@Listeners.register("NativeObject", "zeromq")
class ZeroMQNativeObjectListener(ZeroMQListener):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObject listener using the ZeroMQ message construct assuming the data is serialized as a JSON string.
Deserializes the data (including plugins) using the decoder and parses it to a native object.
@@ -130,7 +161,9 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
:param should_wait: bool: Whether the subscriber should wait for the publisher to transmit a message. Default is True
:param deserializer_kwargs: dict: Additional kwargs for the deserializer
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ super().__init__(
+ name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
self._socket = self._netconnect = None
self._plugin_decoder_hook = JsonDecodeHook(**kwargs).object_hook
@@ -149,9 +182,13 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
self._socket = zmq.Context.instance().socket(zmq.SUB)
for socket_property in ZeroMQMiddlewarePubSub().zeromq_kwargs.items():
if isinstance(socket_property[1], str):
- self._socket.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1])
+ self._socket.setsockopt_string(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
else:
- self._socket.setsockopt(getattr(zmq, socket_property[0]), socket_property[1])
+ self._socket.setsockopt(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
self._socket.connect(self.socket_address)
self._topic = self.in_topic.encode("utf-8")
self._socket.setsockopt_string(zmq.SUBSCRIBE, self.in_topic)
@@ -173,7 +210,11 @@ def listen(self):
if self._socket.poll(timeout=None if self.should_wait else 0):
obj = self._socket.recv_multipart()
if obj is not None:
- return json.loads(obj[1].decode(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ return json.loads(
+ obj[1].decode(),
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
else:
return None
else:
@@ -183,8 +224,19 @@ def listen(self):
@Listeners.register("Image", "zeromq")
class ZeroMQImageListener(ZeroMQNativeObjectListener):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True,
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ **kwargs,
+ ):
"""
The Image listener using the ZeroMQ message construct parsed to a numpy array.
@@ -198,7 +250,9 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
:param fp: bool: True if the image is floating point, False if it is integer. Default is False
:param jpg: bool: True if the image should be decompressed from JPG. Default is False
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ super().__init__(
+ name, in_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -223,14 +277,28 @@ def listen(self):
return None
elif self.jpg:
if self.rgb:
- img = cv2.imdecode(np.frombuffer(obj[2], np.uint8), cv2.IMREAD_COLOR)
+ img = cv2.imdecode(
+ np.frombuffer(obj[2], np.uint8), cv2.IMREAD_COLOR
+ )
else:
- img = cv2.imdecode(np.frombuffer(obj[2], np.uint8), cv2.IMREAD_GRAYSCALE)
+ img = cv2.imdecode(
+ np.frombuffer(obj[2], np.uint8), cv2.IMREAD_GRAYSCALE
+ )
return img
else:
- img = json.loads(obj[2].decode(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
- if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \
- not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)):
+ img = json.loads(
+ obj[2].decode(),
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
+ if (
+ 0 < self.width != img.shape[1]
+ or 0 < self.height != img.shape[0]
+ or not (
+ (img.ndim == 2 and not self.rgb)
+ or (img.ndim == 3 and self.rgb and img.shape[2] == 3)
+ )
+ ):
raise ValueError("Incorrect image shape for listener")
return img
else:
@@ -239,8 +307,17 @@ def listen(self):
@Listeners.register("AudioChunk", "zeromq")
class ZeroMQAudioChunkListener(ZeroMQImageListener):
- def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait: bool = True,
- channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ in_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ **kwargs,
+ ):
"""
The AudioChunk listener using the ZeroMQ message construct parsed to a numpy array.
@@ -253,8 +330,18 @@ def __init__(self, name: str, in_topic: str, carrier: str = "tcp", should_wait:
:param rate: int: Sampling rate of the audio. Default is 44100
:param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio)
"""
- super().__init__(name, in_topic, carrier=carrier, should_wait=should_wait,
- width=chunk, height=channels, rgb=False, fp=True, jpg=False, **kwargs)
+ super().__init__(
+ name,
+ in_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ width=chunk,
+ height=channels,
+ rgb=False,
+ fp=True,
+ jpg=False,
+ **kwargs,
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -271,10 +358,22 @@ def listen(self):
return None
if self._socket.poll(timeout=None if self.should_wait else 0):
obj = self._socket.recv_multipart()
- chunk, channels, rate, aud = json.loads(obj[2].decode(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs) if obj is not None else None
+ chunk, channels, rate, aud = (
+ json.loads(
+ obj[2].decode(),
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
+ if obj is not None
+ else None
+ )
if 0 < self.rate != rate:
raise ValueError("Incorrect audio rate for listener")
- if 0 < self.chunk != chunk or self.channels != channels or aud.size != chunk * channels:
+ if (
+ 0 < self.chunk != chunk
+ or self.channels != channels
+ or aud.size != chunk * channels
+ ):
raise ValueError("Incorrect audio shape for listener")
return aud, rate
else:
diff --git a/wrapyfi/middlewares/ros.py b/wrapyfi/middlewares/ros.py
index 56e9441..1f77e16 100755
--- a/wrapyfi/middlewares/ros.py
+++ b/wrapyfi/middlewares/ros.py
@@ -17,6 +17,7 @@ class ROSMiddleware(metaclass=SingletonOptimized):
and destroy all connections. The ``activate`` and ``deinit`` methods are automatically called when the class is
instantiated and when the program exits, respectively.
"""
+
@staticmethod
def activate(**kwargs):
"""
@@ -26,7 +27,14 @@ def activate(**kwargs):
"""
ROSMiddleware(**kwargs)
- def __init__(self, node_name: str = "wrapyfi", anonymous: bool = True, disable_signals: bool = True, *args, **kwargs):
+ def __init__(
+ self,
+ node_name: str = "wrapyfi",
+ anonymous: bool = True,
+ disable_signals: bool = True,
+ *args,
+ **kwargs,
+ ):
"""
Initialize the ROS middleware. This method is automatically called when the class is instantiated.
@@ -47,18 +55,18 @@ def deinit():
Deinitialize the ROS middleware. This method is automatically called when the program exits.
"""
logging.info("Deinitializing ROS middleware")
- rospy.signal_shutdown('Deinit')
+ rospy.signal_shutdown("Deinit")
class ROSNativeObjectService(object):
- _type = 'wrapyfi_services/ROSNativeObject'
- _md5sum = '46a550fd1ca640b396e26ebf988aed7b' # AddTwoInts '6a2e34150c00229791cc89ff309fff21'
- _request_class = std_msgs.msg.String
- _response_class = std_msgs.msg.String
+ _type = "wrapyfi_services/ROSNativeObject"
+ _md5sum = "46a550fd1ca640b396e26ebf988aed7b" # AddTwoInts '6a2e34150c00229791cc89ff309fff21'
+ _request_class = std_msgs.msg.String
+ _response_class = std_msgs.msg.String
class ROSImageService(object):
- _type = 'wrapyfi_services/ROSImage'
- _md5sum = 'f720f2021b4bbbe86b0f93b08906381c' # AddTwoInts '6a2e34150c00229791cc89ff309fff21'
- _request_class = std_msgs.msg.String
- _response_class = sensor_msgs.msg.Image
+ _type = "wrapyfi_services/ROSImage"
+ _md5sum = "f720f2021b4bbbe86b0f93b08906381c" # AddTwoInts '6a2e34150c00229791cc89ff309fff21'
+ _request_class = std_msgs.msg.String
+ _response_class = sensor_msgs.msg.Image
diff --git a/wrapyfi/middlewares/ros2.py b/wrapyfi/middlewares/ros2.py
index 6786dc1..b1abd50 100755
--- a/wrapyfi/middlewares/ros2.py
+++ b/wrapyfi/middlewares/ros2.py
@@ -45,5 +45,6 @@ def deinit():
if rclpy.ok():
rclpy.shutdown()
else:
- logging.info("ROS 2 context is already shutdown or not initialized. Skipping shutdown.")
-
+ logging.info(
+ "ROS 2 context is already shutdown or not initialized. Skipping shutdown."
+ )
diff --git a/wrapyfi/middlewares/yarp.py b/wrapyfi/middlewares/yarp.py
index 96ba3b0..c0d300b 100755
--- a/wrapyfi/middlewares/yarp.py
+++ b/wrapyfi/middlewares/yarp.py
@@ -14,6 +14,7 @@ class YarpMiddleware(metaclass=SingletonOptimized):
and destroy all connections. The ``activate`` and ``deinit`` methods are automatically called when the class is
instantiated and when the program exits, respectively.
"""
+
@staticmethod
def activate(**kwargs):
"""
diff --git a/wrapyfi/middlewares/zeromq.py b/wrapyfi/middlewares/zeromq.py
index 59c42cc..bc70aae 100644
--- a/wrapyfi/middlewares/zeromq.py
+++ b/wrapyfi/middlewares/zeromq.py
@@ -12,8 +12,20 @@
from wrapyfi.utils import SingletonOptimized
from wrapyfi.connect.wrapper import MiddlewareCommunicator
-ZEROMQ_POST_OPTS = ["SUBSCRIBE", "UNSUBSCRIBE", "LINGER", "ROUTER_HANDOVER", "ROUTER_MANDATORY", "PROBE_ROUTER",
- "XPUB_VERBOSE", "XPUB_VERBOSER", "REQ_CORRELATE", "REQ_RELAXED", "SNDHWM", "RCVHWM"]
+ZEROMQ_POST_OPTS = [
+ "SUBSCRIBE",
+ "UNSUBSCRIBE",
+ "LINGER",
+ "ROUTER_HANDOVER",
+ "ROUTER_MANDATORY",
+ "PROBE_ROUTER",
+ "XPUB_VERBOSE",
+ "XPUB_VERBOSER",
+ "REQ_CORRELATE",
+ "REQ_RELAXED",
+ "SNDHWM",
+ "RCVHWM",
+]
class ZeroMQMiddlewarePubSub(metaclass=SingletonOptimized):
@@ -135,9 +147,18 @@ def activate(**kwargs):
except AttributeError:
pass
- ZeroMQMiddlewarePubSub(zeromq_proxy_kwargs=kwargs, zeromq_post_kwargs=zeromq_post_kwargs, **zeromq_pre_kwargs)
-
- def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwargs: Optional[dict] = None, **kwargs):
+ ZeroMQMiddlewarePubSub(
+ zeromq_proxy_kwargs=kwargs,
+ zeromq_post_kwargs=zeromq_post_kwargs,
+ **zeromq_pre_kwargs,
+ )
+
+ def __init__(
+ self,
+ zeromq_proxy_kwargs: Optional[dict] = None,
+ zeromq_post_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the ZeroMQ PUB/SUB middleware. This method is automatically called when the class is instantiated.
@@ -151,9 +172,13 @@ def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwarg
self.ctx = zmq.Context.instance()
for socket_property in kwargs.items():
if isinstance(socket_property[1], str):
- self.ctx.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1])
+ self.ctx.setsockopt_string(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
else:
- self.ctx.setsockopt(getattr(zmq, socket_property[0]), socket_property[1])
+ self.ctx.setsockopt(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
atexit.register(MiddlewareCommunicator.close_all_instances)
atexit.register(self.deinit)
@@ -162,38 +187,59 @@ def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwarg
# start the pubsub monitor listener
if zeromq_proxy_kwargs.get("pubsub_monitor_listener_spawn", False):
if zeromq_proxy_kwargs["pubsub_monitor_listener_spawn"] == "process":
- self.shared_monitor_data = self.ZeroMQSharedMonitorData(use_multiprocessing=True)
- self.monitor = multiprocessing.Process(name='zeromq_pubsub_monitor_listener',
- target=self.__init_monitor_listener,
- kwargs=zeromq_proxy_kwargs)
+ self.shared_monitor_data = self.ZeroMQSharedMonitorData(
+ use_multiprocessing=True
+ )
+ self.monitor = multiprocessing.Process(
+ name="zeromq_pubsub_monitor_listener",
+ target=self.__init_monitor_listener,
+ kwargs=zeromq_proxy_kwargs,
+ )
self.monitor.daemon = True
self.monitor.start()
else: # if threaded
- self.shared_monitor_data = self.ZeroMQSharedMonitorData(use_multiprocessing=False)
- self.monitor = threading.Thread(name='pubsub_monitor_listener_spawn',
- target=self.__init_monitor_listener, kwargs=zeromq_proxy_kwargs)
- self.monitor.setDaemon(True) # deprecation warning Python3.10. Previous Python versions only support this
+ self.shared_monitor_data = self.ZeroMQSharedMonitorData(
+ use_multiprocessing=False
+ )
+ self.monitor = threading.Thread(
+ name="pubsub_monitor_listener_spawn",
+ target=self.__init_monitor_listener,
+ kwargs=zeromq_proxy_kwargs,
+ )
+ self.monitor.setDaemon(
+ True
+ ) # deprecation warning Python3.10. Previous Python versions only support this
self.monitor.start()
time.sleep(1)
if zeromq_proxy_kwargs.get("start_proxy_broker", False):
if zeromq_proxy_kwargs["proxy_broker_spawn"] == "process":
- self.proxy = multiprocessing.Process(name='zeromq_pubsub_broker', target=self.__init_proxy,
- kwargs=zeromq_proxy_kwargs)
+ self.proxy = multiprocessing.Process(
+ name="zeromq_pubsub_broker",
+ target=self.__init_proxy,
+ kwargs=zeromq_proxy_kwargs,
+ )
self.proxy.daemon = True
self.proxy.start()
else: # if threaded
- self.proxy = threading.Thread(name='zeromq_pubsub_broker', target=self.__init_proxy,
- kwargs=zeromq_proxy_kwargs)
- self.proxy.setDaemon(True) # deprecation warning Python3.10. Previous Python versions only support this
+ self.proxy = threading.Thread(
+ name="zeromq_pubsub_broker",
+ target=self.__init_proxy,
+ kwargs=zeromq_proxy_kwargs,
+ )
+ self.proxy.setDaemon(
+ True
+ ) # deprecation warning Python3.10. Previous Python versions only support this
self.proxy.start()
pass
@staticmethod
- def proxy_thread(socket_pub_address: str = "tcp://127.0.0.1:5555",
- socket_sub_address: str = "tcp://127.0.0.1:5556",
- inproc_address: str = "inproc://monitor"):
+ def proxy_thread(
+ socket_pub_address: str = "tcp://127.0.0.1:5555",
+ socket_sub_address: str = "tcp://127.0.0.1:5556",
+ inproc_address: str = "inproc://monitor",
+ ):
"""
Proxy thread for the ZeroMQ PUB/SUB proxy.
@@ -224,12 +270,17 @@ def proxy_thread(socket_pub_address: str = "tcp://127.0.0.1:5555",
if str(e) == "Socket operation on non-socket":
pass
else:
- logging.error(f"[ZeroMQ BROKER] An error occurred in the ZeroMQ proxy: {str(e)}.")
+ logging.error(
+ f"[ZeroMQ BROKER] An error occurred in the ZeroMQ proxy: {str(e)}."
+ )
@staticmethod
- def subscription_monitor_thread(inproc_address: str = "inproc://monitor",
- socket_sub_address: str = "tcp://127.0.0.1:5556",
- pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS", verbose: bool = False):
+ def subscription_monitor_thread(
+ inproc_address: str = "inproc://monitor",
+ socket_sub_address: str = "tcp://127.0.0.1:5556",
+ pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS",
+ verbose: bool = False,
+ ):
"""
Subscription monitor thread for the ZeroMQ PUB/SUB proxy.
@@ -259,10 +310,12 @@ def subscription_monitor_thread(inproc_address: str = "inproc://monitor",
# ensure the message is a subscription/unsubscription message
if len(message) > 1 and (message[0] == 1 or message[0] == 0):
event = message[0]
- topic = message[1:].decode('utf-8')
+ topic = message[1:].decode("utf-8")
if verbose:
- logging.info(f"[ZeroMQ BROKER] Received event: {event}, topic: {topic}")
+ logging.info(
+ f"[ZeroMQ BROKER] Received event: {event}, topic: {topic}"
+ )
# avoid processing messages on the monitor topic.
if topic == pubsub_monitor_topic:
@@ -275,19 +328,29 @@ def subscription_monitor_thread(inproc_address: str = "inproc://monitor",
topic_subscriber_count[topic] = 0
if verbose:
- logging.info(f"[ZeroMQ BROKER] Current topic subscriber count: {dict(topic_subscriber_count)}")
+ logging.info(
+ f"[ZeroMQ BROKER] Current topic subscriber count: {dict(topic_subscriber_count)}"
+ )
# publish the updated counts
publisher.send_multipart(
- [pubsub_monitor_topic.encode(), json.dumps(dict(topic_subscriber_count)).encode()]
+ [
+ pubsub_monitor_topic.encode(),
+ json.dumps(dict(topic_subscriber_count)).encode(),
+ ]
)
except Exception as e:
- logging.error(f"[ZeroMQ BROKER] An error occurred in the ZeroMQ subscription monitor: {str(e)}")
-
- def __init_proxy(self, socket_pub_address: str = "tcp://127.0.0.1:5555",
- socket_sub_address: str = "tcp://127.0.0.1:5556",
- pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS",
- **kwargs):
+ logging.error(
+ f"[ZeroMQ BROKER] An error occurred in the ZeroMQ subscription monitor: {str(e)}"
+ )
+
+ def __init_proxy(
+ self,
+ socket_pub_address: str = "tcp://127.0.0.1:5555",
+ socket_sub_address: str = "tcp://127.0.0.1:5556",
+ pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS",
+ **kwargs,
+ ):
"""
Initialize the ZeroMQ PUB/SUB proxy and subscription monitor.
@@ -298,20 +361,32 @@ def __init_proxy(self, socket_pub_address: str = "tcp://127.0.0.1:5555",
"""
inproc_address = "inproc://monitor"
- threading.Thread(target=self.proxy_thread,
- kwargs={"socket_pub_address": socket_pub_address,
- "socket_sub_address": socket_sub_address,
- "inproc_address": inproc_address}).start(),
-
- threading.Thread(target=self.subscription_monitor_thread,
- kwargs={"socket_sub_address": socket_sub_address,
- "inproc_address": inproc_address,
- "pubsub_monitor_topic": pubsub_monitor_topic,
- "verbose": kwargs.get("verbose", False)}).start()
-
- def __init_monitor_listener(self, socket_pub_address: str = "tcp://127.0.0.1:5555",
- pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS",
- verbose: bool = False, **kwargs):
+ threading.Thread(
+ target=self.proxy_thread,
+ kwargs={
+ "socket_pub_address": socket_pub_address,
+ "socket_sub_address": socket_sub_address,
+ "inproc_address": inproc_address,
+ },
+ ).start(),
+
+ threading.Thread(
+ target=self.subscription_monitor_thread,
+ kwargs={
+ "socket_sub_address": socket_sub_address,
+ "inproc_address": inproc_address,
+ "pubsub_monitor_topic": pubsub_monitor_topic,
+ "verbose": kwargs.get("verbose", False),
+ },
+ ).start()
+
+ def __init_monitor_listener(
+ self,
+ socket_pub_address: str = "tcp://127.0.0.1:5555",
+ pubsub_monitor_topic: str = "ZEROMQ/CONNECTIONS",
+ verbose: bool = False,
+ **kwargs,
+ ):
"""
Initialize the ZeroMQ PUB/SUB monitor listener.
@@ -328,7 +403,7 @@ def __init_monitor_listener(self, socket_pub_address: str = "tcp://127.0.0.1:555
while True:
_, message = subscriber.recv_multipart()
- data = json.loads(message.decode('utf-8'))
+ data = json.loads(message.decode("utf-8"))
topic = list(data.keys())[0]
if verbose:
logging.info(f"[ZeroMQ] Topic: {topic}, Data: {data}")
@@ -336,17 +411,23 @@ def __init_monitor_listener(self, socket_pub_address: str = "tcp://127.0.0.1:555
if topic in self.shared_monitor_data.get_topics():
self.shared_monitor_data.update_connection(topic, data)
if data[topic] == 0:
- logging.info(f"[ZeroMQ] Subscriber disconnected from topic: {topic}")
+ logging.info(
+ f"[ZeroMQ] Subscriber disconnected from topic: {topic}"
+ )
self.shared_monitor_data.remove_connection(topic)
else:
logging.info(f"[ZeroMQ] Subscriber connected to topic: {topic}")
if verbose:
for monitored_topic in self.shared_monitor_data.get_topics():
- logging.info(f"[ZeroMQ] Monitored topic from main process: {monitored_topic}")
+ logging.info(
+ f"[ZeroMQ] Monitored topic from main process: {monitored_topic}"
+ )
except Exception as e:
- logging.error(f"[ZeroMQ] An error occurred in the ZeroMQ subscription monitor listener: {str(e)}")
+ logging.error(
+ f"[ZeroMQ] An error occurred in the ZeroMQ subscription monitor listener: {str(e)}"
+ )
@staticmethod
def deinit():
@@ -381,10 +462,19 @@ def activate(**kwargs):
except AttributeError:
pass
- ZeroMQMiddlewareReqRep(zeromq_proxy_kwargs=kwargs, zeromq_post_kwargs=zeromq_post_kwargs, **zeromq_pre_kwargs)
-
- def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwargs: Optional[dict] = None, *args,
- **kwargs):
+ ZeroMQMiddlewareReqRep(
+ zeromq_proxy_kwargs=kwargs,
+ zeromq_post_kwargs=zeromq_post_kwargs,
+ **zeromq_pre_kwargs,
+ )
+
+ def __init__(
+ self,
+ zeromq_proxy_kwargs: Optional[dict] = None,
+ zeromq_post_kwargs: Optional[dict] = None,
+ *args,
+ **kwargs,
+ ):
"""
Initialize the ZeroMQ REQ/REP middleware. This method is automatically called when the class is instantiated.
@@ -399,28 +489,43 @@ def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwarg
self.ctx = zmq.Context.instance()
for socket_property in kwargs.items():
if isinstance(socket_property[1], str):
- self.ctx.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1])
+ self.ctx.setsockopt_string(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
else:
- self.ctx.setsockopt(getattr(zmq, socket_property[0]), socket_property[1])
+ self.ctx.setsockopt(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
atexit.register(MiddlewareCommunicator.close_all_instances)
atexit.register(self.deinit)
if zeromq_proxy_kwargs is not None and zeromq_proxy_kwargs:
if zeromq_proxy_kwargs["proxy_broker_spawn"] == "process":
- self.proxy = multiprocessing.Process(name='zeromq_reqrep_broker', target=self.__init_device,
- kwargs=zeromq_proxy_kwargs)
+ self.proxy = multiprocessing.Process(
+ name="zeromq_reqrep_broker",
+ target=self.__init_device,
+ kwargs=zeromq_proxy_kwargs,
+ )
self.proxy.daemon = True
self.proxy.start()
else: # if threaded
- self.proxy = threading.Thread(name='zeromq_reqrep_broker', target=self.__init_device,
- kwargs=zeromq_proxy_kwargs)
- self.proxy.setDaemon(True) # deprecation warning Python3.10. Previous Python versions only support this
+ self.proxy = threading.Thread(
+ name="zeromq_reqrep_broker",
+ target=self.__init_device,
+ kwargs=zeromq_proxy_kwargs,
+ )
+ self.proxy.setDaemon(
+ True
+ ) # deprecation warning Python3.10. Previous Python versions only support this
self.proxy.start()
pass
@staticmethod
- def __init_device(socket_rep_address: str = "tcp://127.0.0.1:5559",
- socket_req_address: str = "tcp://127.0.0.1:5560", **kwargs):
+ def __init_device(
+ socket_rep_address: str = "tcp://127.0.0.1:5559",
+ socket_req_address: str = "tcp://127.0.0.1:5560",
+ **kwargs,
+ ):
"""
Initialize the ZeroMQ REQ/REP device broker.
@@ -453,7 +558,9 @@ def __init_device(socket_rep_address: str = "tcp://127.0.0.1:5559",
if str(e) == "Socket operation on non-socket":
pass
else:
- logging.error(f"[ZeroMQ] An error occurred in the ZeroMQ proxy: {str(e)}.")
+ logging.error(
+ f"[ZeroMQ] An error occurred in the ZeroMQ proxy: {str(e)}."
+ )
@staticmethod
def deinit():
@@ -485,11 +592,19 @@ def activate(**kwargs):
except AttributeError:
pass
- ZeroMQMiddlewareParamServer(zeromq_proxy_kwargs=kwargs, zeromq_post_kwargs=zeromq_post_kwargs,
- **zeromq_pre_kwargs)
-
- def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwargs: Optional = None, *args,
- **kwargs):
+ ZeroMQMiddlewareParamServer(
+ zeromq_proxy_kwargs=kwargs,
+ zeromq_post_kwargs=zeromq_post_kwargs,
+ **zeromq_pre_kwargs,
+ )
+
+ def __init__(
+ self,
+ zeromq_proxy_kwargs: Optional[dict] = None,
+ zeromq_post_kwargs: Optional = None,
+ *args,
+ **kwargs,
+ ):
"""
Initialize the ZeroMQ parameter server middleware. This method is automatically called when the class is
instantiated.
@@ -504,9 +619,13 @@ def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwarg
self.ctx = zmq.Context.instance()
for socket_property in kwargs.items():
if isinstance(socket_property[1], str):
- self.ctx.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1])
+ self.ctx.setsockopt_string(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
else:
- self.ctx.setsockopt(getattr(zmq, socket_property[0]), socket_property[1])
+ self.ctx.setsockopt(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
atexit.register(MiddlewareCommunicator.close_all_instances)
atexit.register(self.deinit)
@@ -516,32 +635,55 @@ def __init__(self, zeromq_proxy_kwargs: Optional[dict] = None, zeromq_post_kwarg
self.params = self.manager.dict()
self.params["WRAPYFI_ACTIVE"] = "True"
if zeromq_proxy_kwargs["proxy_broker_spawn"] == "process":
- self.param_broadcaster = multiprocessing.Process(name='zeromq_param_broadcaster',
- target=self.__init_broadcaster,
- kwargs=zeromq_proxy_kwargs, args=(self.params,))
+ self.param_broadcaster = multiprocessing.Process(
+ name="zeromq_param_broadcaster",
+ target=self.__init_broadcaster,
+ kwargs=zeromq_proxy_kwargs,
+ args=(self.params,),
+ )
self.param_broadcaster.daemon = True
self.param_broadcaster.start()
- self.param_server = multiprocessing.Process(name='zeromq_param_server', target=self.__init_server,
- kwargs=zeromq_proxy_kwargs, args=(self.params,))
+ self.param_server = multiprocessing.Process(
+ name="zeromq_param_server",
+ target=self.__init_server,
+ kwargs=zeromq_proxy_kwargs,
+ args=(self.params,),
+ )
self.param_server.daemon = True
self.param_server.start()
else: # if threaded
- self.param_broadcaster = threading.Thread(name='zeromq_param_broadcaster',
- target=self.__init_broadcaster,
- kwargs=zeromq_proxy_kwargs, args=(self.params,))
- self.param_broadcaster.setDaemon(True) # deprecation warning Python3.10. Previous Python versions only support this
+ self.param_broadcaster = threading.Thread(
+ name="zeromq_param_broadcaster",
+ target=self.__init_broadcaster,
+ kwargs=zeromq_proxy_kwargs,
+ args=(self.params,),
+ )
+ self.param_broadcaster.setDaemon(
+ True
+ ) # deprecation warning Python3.10. Previous Python versions only support this
self.param_broadcaster.start()
- self.param_server = threading.Thread(name='zeromq_param_server', target=self.__init_server,
- kwargs=zeromq_proxy_kwargs, args=(self.params,))
- self.param_server.setDaemon(True) # deprecation warning Python3.10. Previous Python versions only support this
+ self.param_server = threading.Thread(
+ name="zeromq_param_server",
+ target=self.__init_server,
+ kwargs=zeromq_proxy_kwargs,
+ args=(self.params,),
+ )
+ self.param_server.setDaemon(
+ True
+ ) # deprecation warning Python3.10. Previous Python versions only support this
self.param_server.start()
pass
@staticmethod
- def __init_broadcaster(params, param_pub_address: str = "tcp://127.0.0.1:5655",
- param_sub_address: str = "tcp://127.0.0.1:5656",
- param_poll_interval=1, verbose=False, **kwargs):
+ def __init_broadcaster(
+ params,
+ param_pub_address: str = "tcp://127.0.0.1:5655",
+ param_sub_address: str = "tcp://127.0.0.1:5656",
+ param_poll_interval=1,
+ verbose=False,
+ **kwargs,
+ ):
"""
Initialize the ZeroMQ parameter server broadcaster.
@@ -580,7 +722,9 @@ def __init_broadcaster(params, param_pub_address: str = "tcp://127.0.0.1:5655",
if xpub_socket in event:
message = xpub_socket.recv_multipart()
if verbose:
- logging.info("[ZeroMQ BROKER] xpub_socket recv message: %r" % message)
+ logging.info(
+ "[ZeroMQ BROKER] xpub_socket recv message: %r" % message
+ )
if message[0].startswith(b"\x00"):
root_topics.remove(message[0][1:].decode("utf-8"))
elif message[0].startswith(b"\x01"):
@@ -590,29 +734,52 @@ def __init_broadcaster(params, param_pub_address: str = "tcp://127.0.0.1:5655",
if xsub_socket in event:
message = xsub_socket.recv_multipart()
if verbose:
- logging.info("[ZeroMQ BROKER] xsub_socket recv message: %r" % message)
+ logging.info(
+ "[ZeroMQ BROKER] xsub_socket recv message: %r" % message
+ )
if message[0].startswith(b"\x01") or message[0].startswith(b"\x00"):
xpub_socket.send_multipart(message)
else:
fltr_key = message[0].decode("utf-8")
- fltr_message = {key: val for key, val in params.items()
- if key.startswith(fltr_key)}
+ fltr_message = {
+ key: val
+ for key, val in params.items()
+ if key.startswith(fltr_key)
+ }
if verbose:
- logging.info("[ZeroMQ BROKER] xsub_socket filtered message: %r" % fltr_message)
+ logging.info(
+ "[ZeroMQ BROKER] xsub_socket filtered message: %r"
+ % fltr_message
+ )
for key, val in fltr_message.items():
prefix, param = key.rsplit("/", 1) if "/" in key else ("", key)
- xpub_socket.send_multipart([prefix.encode("utf-8"), param.encode("utf-8"), val.encode("utf-8")])
+ xpub_socket.send_multipart(
+ [
+ prefix.encode("utf-8"),
+ param.encode("utf-8"),
+ val.encode("utf-8"),
+ ]
+ )
# xpub_socket.send_multipart(message)
if event:
update_trigger = True
if param_server is not None:
- update_trigger, cached_params = ZeroMQMiddlewareParamServer.publish_params(
- param_server, params, cached_params, root_topics, update_trigger)
+ update_trigger, cached_params = (
+ ZeroMQMiddlewareParamServer.publish_params(
+ param_server, params, cached_params, root_topics, update_trigger
+ )
+ )
@staticmethod
- def publish_params(param_server, params: dict, cached_params: dict, root_topics: set, update_trigger: bool):
+ def publish_params(
+ param_server,
+ params: dict,
+ cached_params: dict,
+ root_topics: set,
+ update_trigger: bool,
+ ):
"""
Publish parameters to the parameter server.
@@ -637,12 +804,16 @@ def publish_params(param_server, params: dict, cached_params: dict, root_topics:
# publish updates for all parameters to subscribed clients
for key, val in params.items():
prefix, param = key.rsplit("/", 1) if "/" in key else ("", key)
- param_server.send_multipart([prefix.encode("utf-8"), param.encode("utf-8"), val.encode("utf-8")])
+ param_server.send_multipart(
+ [prefix.encode("utf-8"), param.encode("utf-8"), val.encode("utf-8")]
+ )
return update_trigger, cached_params
@staticmethod
- def __init_server(params: dict, param_reqrep_address: str = "tcp://127.0.0.1:5659", **kwargs):
+ def __init_server(
+ params: dict, param_reqrep_address: str = "tcp://127.0.0.1:5659", **kwargs
+ ):
"""
Initialize the ZeroMQ parameter server.
@@ -658,7 +829,11 @@ def __init_server(params: dict, param_reqrep_address: str = "tcp://127.0.0.1:565
if request.startswith("get"):
try:
# extract the parameter name and namespace prefix from the request
- prefix, param = request[4:].rsplit("/", 1) if "/" in request[4:] else ("", request[4:])
+ prefix, param = (
+ request[4:].rsplit("/", 1)
+ if "/" in request[4:]
+ else ("", request[4:])
+ )
# construct the full parameter name with the namespace prefix
full_param = "/".join([prefix, param]) if prefix else param
if full_param in params:
diff --git a/wrapyfi/plugins/cupy_array.py b/wrapyfi/plugins/cupy_array.py
index f4aac63..00e318b 100644
--- a/wrapyfi/plugins/cupy_array.py
+++ b/wrapyfi/plugins/cupy_array.py
@@ -27,6 +27,7 @@
try:
import cupy as cp
+
HAVE_CUPY = True
except ImportError:
HAVE_CUPY = False
@@ -75,7 +76,9 @@ def __init__(self, load_cupy_device=None, map_cupy_devices=None, **kwargs):
self.map_cupy_devices = map_cupy_devices or {}
if load_cupy_device is not None:
self.map_cupy_devices["default"] = load_cupy_device
- self.map_cupy_devices = {k: cupy_device_to_str(v) for k, v in self.map_cupy_devices.items()}
+ self.map_cupy_devices = {
+ k: cupy_device_to_str(v) for k, v in self.map_cupy_devices.items()
+ }
def encode(self, obj, *args, **kwargs):
"""
@@ -86,9 +89,11 @@ def encode(self, obj, *args, **kwargs):
"""
with io.BytesIO() as memfile:
np.save(memfile, cp.asnumpy(obj))
- obj_data = base64.b64encode(memfile.getvalue()).decode('ascii')
+ obj_data = base64.b64encode(memfile.getvalue()).decode("ascii")
obj_device = cupy_device_to_str(obj.device)
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device))
+ return True, dict(
+ __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device)
+ )
def decode(self, obj_type, obj_full, *args, **kwargs):
"""
@@ -97,9 +102,10 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
:param obj_full: tuple: A tuple containing the encoded data string and device string
:return: Tuple[bool, cp.ndarray]
"""
- with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile:
- obj_device_str = self.map_cupy_devices.get(obj_full[2], self.map_cupy_devices.get("default", "cuda:0"))
+ with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile:
+ obj_device_str = self.map_cupy_devices.get(
+ obj_full[2], self.map_cupy_devices.get("default", "cuda:0")
+ )
obj_device = cupy_str_to_device(obj_device_str)
with obj_device:
return True, cp.array(np.load(memfile))
-
diff --git a/wrapyfi/plugins/dask_data.py b/wrapyfi/plugins/dask_data.py
index f8b3e40..e206168 100644
--- a/wrapyfi/plugins/dask_data.py
+++ b/wrapyfi/plugins/dask_data.py
@@ -29,12 +29,19 @@
import dask.dataframe as dd
import dask.array as da
import pandas as pd
+
HAVE_DASK = True
except ImportError:
HAVE_DASK = False
-@PluginRegistrar.register(types=None if not HAVE_DASK else dd.DataFrame.__mro__[:-1] + dd.Series.__mro__[:-1] + da.Array.__mro__[:-1])
+@PluginRegistrar.register(
+ types=(
+ None
+ if not HAVE_DASK
+ else dd.DataFrame.__mro__[:-1] + dd.Series.__mro__[:-1] + da.Array.__mro__[:-1]
+ )
+)
class DaskData(Plugin):
def __init__(self, **kwargs):
"""
@@ -59,20 +66,27 @@ def encode(self, obj, *args, **kwargs):
with io.BytesIO() as memfile:
if isinstance(obj, dd.DataFrame):
pandas_df = obj.compute().reset_index()
- memfile.write(pandas_df.to_json(orient="records").encode('ascii'))
+ memfile.write(pandas_df.to_json(orient="records").encode("ascii"))
obj_partitions = obj.npartitions
- obj_type = 'DataFrame'
+ obj_type = "DataFrame"
elif isinstance(obj, dd.Series):
pandas_ds = obj.compute().reset_index()
- memfile.write(pandas_ds.to_json(orient="records").encode('ascii'))
+ memfile.write(pandas_ds.to_json(orient="records").encode("ascii"))
obj_partitions = obj.npartitions
- obj_type = 'Series'
+ obj_type = "Series"
elif isinstance(obj, da.Array):
np.save(memfile, obj.compute(), allow_pickle=True)
- obj_type = 'Array'
+ obj_type = "Array"
memfile.seek(0)
- obj_data = base64.b64encode(memfile.read()).decode('ascii')
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type, obj_partitions))
+ obj_data = base64.b64encode(memfile.read()).decode("ascii")
+ return True, dict(
+ __wrapyfi__=(
+ str(self.__class__.__name__),
+ obj_data,
+ obj_type,
+ obj_partitions,
+ )
+ )
def decode(self, obj_type, obj_full, *args, **kwargs):
"""
@@ -86,18 +100,26 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- bool: Always True, indicating that the decoding was successful
- Union[dd.DataFrame, da.Array]: The decoded Dask data
"""
- obj_data = base64.b64decode(obj_full[1].encode('ascii'))
+ obj_data = base64.b64decode(obj_full[1].encode("ascii"))
obj_type = obj_full[2]
obj_partitions = obj_full[3]
with io.BytesIO(obj_data) as memfile:
- if obj_type == 'DataFrame':
- pandas_df = pd.read_json(memfile.read().decode('ascii'), orient="records")
- pandas_df.set_index('index', inplace=True) # Set the index back after reading from JSON
+ if obj_type == "DataFrame":
+ pandas_df = pd.read_json(
+ memfile.read().decode("ascii"), orient="records"
+ )
+ pandas_df.set_index(
+ "index", inplace=True
+ ) # Set the index back after reading from JSON
return True, dd.from_pandas(pandas_df, npartitions=obj_partitions)
- elif obj_type == 'Series':
- pandas_ds = pd.read_json(memfile.read().decode('ascii'), orient="records")
- pandas_ds.set_index('index', inplace=True) # Set the index back after reading from JSON
+ elif obj_type == "Series":
+ pandas_ds = pd.read_json(
+ memfile.read().decode("ascii"), orient="records"
+ )
+ pandas_ds.set_index(
+ "index", inplace=True
+ ) # Set the index back after reading from JSON
return True, dd.from_pandas(pandas_ds, npartitions=obj_partitions)
- elif obj_type == 'Array':
+ elif obj_type == "Array":
np_array = np.load(memfile, allow_pickle=True)
- return True, da.from_array(np_array, chunks=np_array.shape)
\ No newline at end of file
+ return True, da.from_array(np_array, chunks=np_array.shape)
diff --git a/wrapyfi/plugins/jax_tensor.py b/wrapyfi/plugins/jax_tensor.py
index fdf43ff..d7d3bf1 100644
--- a/wrapyfi/plugins/jax_tensor.py
+++ b/wrapyfi/plugins/jax_tensor.py
@@ -26,6 +26,7 @@
try:
import jax
+
HAVE_JAX = True
except ImportError:
HAVE_JAX = False
@@ -59,7 +60,7 @@ def encode(self, obj, *args, **kwargs):
"""
with io.BytesIO() as memfile:
np.save(memfile, np.asarray(obj))
- obj_data = base64.b64encode(memfile.getvalue()).decode('ascii')
+ obj_data = base64.b64encode(memfile.getvalue()).decode("ascii")
return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data))
def decode(self, obj_type, obj_full, *args, **kwargs):
@@ -74,6 +75,5 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- bool: Always True, indicating that the decoding was successful
- jax.numpy.DeviceArray: The decoded JAX tensor data
"""
- with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile:
+ with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile:
return True, jax.numpy.array(np.load(memfile))
-
diff --git a/wrapyfi/plugins/mxnet_tensor.py b/wrapyfi/plugins/mxnet_tensor.py
index 400bf6f..6e396cf 100644
--- a/wrapyfi/plugins/mxnet_tensor.py
+++ b/wrapyfi/plugins/mxnet_tensor.py
@@ -28,6 +28,7 @@
try:
import mxnet
+
HAVE_MXNET = True
except ImportError:
HAVE_MXNET = False
@@ -41,7 +42,7 @@ def mxnet_device_to_str(device):
:return: Union[str, dict]: A string or dictionary representing the MXNet device
"""
if device is None:
- return 'cpu:0'
+ return "cpu:0"
elif isinstance(device, dict):
device_rets = {}
for k, v in device.items():
@@ -52,7 +53,7 @@ def mxnet_device_to_str(device):
elif isinstance(device, str):
return device.replace("gpu", "cuda")
else:
- raise ValueError(f'Unknown device type {type(device)}')
+ raise ValueError(f"Unknown device type {type(device)}")
@lru_cache(maxsize=None)
@@ -69,16 +70,18 @@ def mxnet_str_to_device(device):
return device
elif isinstance(device, str):
try:
- device_type, device_id = device.split(':')
+ device_type, device_id = device.split(":")
except ValueError:
device_type = device
device_id = 0
return mxnet.Context(device_type.replace("cuda", "gpu"), int(device_id))
else:
- raise ValueError(f'Unknown device type {type(device)}')
+ raise ValueError(f"Unknown device type {type(device)}")
-@PluginRegistrar.register(types=None if not HAVE_MXNET else mxnet.nd.NDArray.__mro__[:-1])
+@PluginRegistrar.register(
+ types=None if not HAVE_MXNET else mxnet.nd.NDArray.__mro__[:-1]
+)
class MXNetTensor(Plugin):
def __init__(self, load_mxnet_device=None, map_mxnet_devices=None, **kwargs):
"""
@@ -89,7 +92,7 @@ def __init__(self, load_mxnet_device=None, map_mxnet_devices=None, **kwargs):
"""
self.map_mxnet_devices = map_mxnet_devices or {}
if load_mxnet_device is not None:
- self.map_mxnet_devices['default'] = load_mxnet_device
+ self.map_mxnet_devices["default"] = load_mxnet_device
self.map_mxnet_devices = mxnet_device_to_str(self.map_mxnet_devices)
def encode(self, obj, *args, **kwargs):
@@ -106,9 +109,11 @@ def encode(self, obj, *args, **kwargs):
"""
with io.BytesIO() as memfile:
np.save(memfile, obj.asnumpy())
- obj_data = base64.b64encode(memfile.getvalue()).decode('ascii')
+ obj_data = base64.b64encode(memfile.getvalue()).decode("ascii")
obj_device = mxnet_device_to_str(obj.context)
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device))
+ return True, dict(
+ __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device)
+ )
def decode(self, obj_type, obj_full, *args, **kwargs):
"""
@@ -122,8 +127,10 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- bool: Always True, indicating that the decoding was successful
- mxnet.nd.NDArray: The decoded MXNet tensor data
"""
- with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile:
- obj_device = self.map_mxnet_devices.get(obj_full[2], self.map_mxnet_devices.get('default', None))
+ with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile:
+ obj_device = self.map_mxnet_devices.get(
+ obj_full[2], self.map_mxnet_devices.get("default", None)
+ )
if obj_device is not None:
obj_device = mxnet_str_to_device(obj_device)
return True, mxnet.nd.array(np.load(memfile), ctx=obj_device)
diff --git a/wrapyfi/plugins/paddle_tensor.py b/wrapyfi/plugins/paddle_tensor.py
index 2a1113e..c66dcc3 100644
--- a/wrapyfi/plugins/paddle_tensor.py
+++ b/wrapyfi/plugins/paddle_tensor.py
@@ -29,6 +29,7 @@
try:
import paddle
from paddle import Tensor
+
HAVE_PADDLE = True
except ImportError:
HAVE_PADDLE = False
@@ -53,7 +54,7 @@ def paddle_device_to_str(device):
elif isinstance(device, str):
return device.replace("cuda", "gpu")
else:
- raise ValueError(f'Unknown device type {type(device)}')
+ raise ValueError(f"Unknown device type {type(device)}")
@lru_cache(maxsize=None)
@@ -70,12 +71,12 @@ def paddle_str_to_device(device):
return device
elif isinstance(device, str):
try:
- device_type, device_id = device.split(':')
+ device_type, device_id = device.split(":")
except ValueError:
device_type = device
return paddle.device._convert_to_place(device_type.replace("cuda", "gpu"))
else:
- raise ValueError(f'Unknown device type {type(device)}')
+ raise ValueError(f"Unknown device type {type(device)}")
@PluginRegistrar.register(types=None if not HAVE_PADDLE else paddle.Tensor.__mro__[:-1])
@@ -89,7 +90,7 @@ def __init__(self, load_paddle_device=None, map_paddle_devices=None, **kwargs):
"""
self.map_paddle_devices = map_paddle_devices or {}
if load_paddle_device is not None:
- self.map_paddle_devices['default'] = load_paddle_device
+ self.map_paddle_devices["default"] = load_paddle_device
self.map_paddle_devices = paddle_device_to_str(self.map_paddle_devices)
def encode(self, obj, *args, **kwargs):
@@ -106,9 +107,11 @@ def encode(self, obj, *args, **kwargs):
"""
with io.BytesIO() as memfile:
paddle.save(obj, memfile)
- obj_data = base64.b64encode(memfile.getvalue()).decode('ascii')
+ obj_data = base64.b64encode(memfile.getvalue()).decode("ascii")
obj_device = paddle_device_to_str(obj.place)
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device))
+ return True, dict(
+ __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device)
+ )
def decode(self, obj_type, obj_full, *args, **kwargs):
"""
@@ -122,11 +125,12 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- bool: Always True, indicating that the decoding was successful
- paddle.Tensor: The decoded PaddlePaddle tensor data
"""
- with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile:
- obj_device = self.map_paddle_devices.get(obj_full[2], self.map_paddle_devices.get('default', None))
+ with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile:
+ obj_device = self.map_paddle_devices.get(
+ obj_full[2], self.map_paddle_devices.get("default", None)
+ )
if obj_device is not None:
obj_device = paddle_str_to_device(obj_device)
return True, paddle.Tensor(paddle.load(memfile), place=obj_device)
else:
return True, paddle.load(memfile)
-
diff --git a/wrapyfi/plugins/pandas_data.py b/wrapyfi/plugins/pandas_data.py
index 6a0d005..c0641a6 100644
--- a/wrapyfi/plugins/pandas_data.py
+++ b/wrapyfi/plugins/pandas_data.py
@@ -26,12 +26,19 @@
try:
import pandas
+
HAVE_PANDAS = True
except ImportError:
HAVE_PANDAS = False
-@PluginRegistrar.register(types=None if not HAVE_PANDAS else pandas.DataFrame.__mro__[:-1] + pandas.Series.__mro__[:-1])
+@PluginRegistrar.register(
+ types=(
+ None
+ if not HAVE_PANDAS
+ else pandas.DataFrame.__mro__[:-1] + pandas.Series.__mro__[:-1]
+ )
+)
class PandasData(Plugin):
def __init__(self, **kwargs):
"""
@@ -53,13 +60,15 @@ def encode(self, obj, *args, **kwargs):
"""
with io.BytesIO() as memfile:
if isinstance(obj, pandas.DataFrame):
- obj_type = 'DataFrame'
+ obj_type = "DataFrame"
obj.to_json(memfile, orient="split")
elif isinstance(obj, pandas.Series):
- obj_type = 'Series'
+ obj_type = "Series"
obj.to_frame().to_json(memfile, orient="split")
- obj_data = base64.b64encode(memfile.getvalue()).decode('ascii')
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type))
+ obj_data = base64.b64encode(memfile.getvalue()).decode("ascii")
+ return True, dict(
+ __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type)
+ )
def decode(self, obj_type, obj_full, *args, **kwargs):
"""
@@ -73,10 +82,9 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- bool: Always True, indicating that the decoding was successful
- Union[pandas.DataFrame, pandas.Series]: The decoded pandas data
"""
- with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile:
+ with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile:
obj = pandas.read_json(memfile, orient="split")
obj_type = obj_full[2]
- if obj_type == 'Series':
+ if obj_type == "Series":
obj = obj.iloc[:, 0]
return True, obj
-
diff --git a/wrapyfi/plugins/pillow_image.py b/wrapyfi/plugins/pillow_image.py
index 42f56fd..f4f0447 100644
--- a/wrapyfi/plugins/pillow_image.py
+++ b/wrapyfi/plugins/pillow_image.py
@@ -24,6 +24,7 @@
try:
from PIL import Image
+
HAVE_PIL = True
except ImportError:
HAVE_PIL = False
@@ -50,12 +51,14 @@ def encode(self, obj, *args, **kwargs):
- '__wrapyfi__': A tuple containing the class name and encoded data string, with optional image size and mode for raw data
"""
if obj.format is None:
- obj_data = base64.b64encode(obj.tobytes()).decode('ascii')
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj.size, obj.mode))
+ obj_data = base64.b64encode(obj.tobytes()).decode("ascii")
+ return True, dict(
+ __wrapyfi__=(str(self.__class__.__name__), obj_data, obj.size, obj.mode)
+ )
else:
with io.BytesIO() as memfile:
obj.save(memfile, format=obj.format)
- obj_data = memfile.getvalue().decode('latin1')
+ obj_data = memfile.getvalue().decode("latin1")
return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data))
def decode(self, obj_type, obj_full, *args, **kwargs):
@@ -71,10 +74,11 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- Image.Image: The decoded PIL Image data
"""
if len(obj_full) == 4:
- with io.BytesIO(obj_full[1].encode('ascii')) as memfile:
- return True, Image.frombytes(obj_full[3], obj_full[2], memfile.getvalue(), "raw")
+ with io.BytesIO(obj_full[1].encode("ascii")) as memfile:
+ return True, Image.frombytes(
+ obj_full[3], obj_full[2], memfile.getvalue(), "raw"
+ )
else:
- with io.BytesIO(obj_full[1].encode('latin1')) as memfile:
+ with io.BytesIO(obj_full[1].encode("latin1")) as memfile:
memfile.seek(0)
return True, Image.open(memfile).copy()
-
diff --git a/wrapyfi/plugins/pint_quantities.py b/wrapyfi/plugins/pint_quantities.py
index 6a9ab9e..ad64cc6 100644
--- a/wrapyfi/plugins/pint_quantities.py
+++ b/wrapyfi/plugins/pint_quantities.py
@@ -24,6 +24,7 @@
try:
from pint import Quantity
+
HAVE_PINT = True
except ImportError:
HAVE_PINT = False
@@ -50,13 +51,14 @@ def encode(self, obj, *args, **kwargs):
- '__wrapyfi__': A tuple containing the class name, encoded data string, and object type
"""
if isinstance(obj, Quantity):
- obj_type = 'Quantity'
- obj_data = json.dumps({
- 'magnitude': obj.magnitude,
- 'units': str(obj.units)
- }).encode('ascii')
- obj_data = base64.b64encode(obj_data).decode('ascii')
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type))
+ obj_type = "Quantity"
+ obj_data = json.dumps(
+ {"magnitude": obj.magnitude, "units": str(obj.units)}
+ ).encode("ascii")
+ obj_data = base64.b64encode(obj_data).decode("ascii")
+ return True, dict(
+ __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type)
+ )
else:
# TypeError("Unknown object type: {}".format(obj_type))
return False, {}
@@ -73,13 +75,16 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- bool: Indicating that the decoding was successful
- pint.Quantity: The decoded Pint Quantity data
"""
- obj_data = base64.b64decode(obj_full[1].encode('ascii')).decode('ascii')
+ obj_data = base64.b64decode(obj_full[1].encode("ascii")).decode("ascii")
obj_data = json.loads(obj_data)
obj_type = obj_full[2]
- if obj_type == 'Quantity':
+ if obj_type == "Quantity":
from pint import UnitRegistry
+
ureg = UnitRegistry()
- obj = Quantity(obj_data['magnitude'], ureg.parse_expression(obj_data['units']))
+ obj = Quantity(
+ obj_data["magnitude"], ureg.parse_expression(obj_data["units"])
+ )
return True, obj
else:
# TypeError("Unknown object type: {}".format(obj_type))
diff --git a/wrapyfi/plugins/pyarrow_array.py b/wrapyfi/plugins/pyarrow_array.py
index 3fd29be..24f9a27 100644
--- a/wrapyfi/plugins/pyarrow_array.py
+++ b/wrapyfi/plugins/pyarrow_array.py
@@ -24,12 +24,15 @@
try:
import pyarrow as pa
+
HAVE_PYARROW = True
except ImportError:
HAVE_PYARROW = False
-@PluginRegistrar.register(types=None if not HAVE_PYARROW else pa.StructArray.__mro__[:-1])
+@PluginRegistrar.register(
+ types=None if not HAVE_PYARROW else pa.StructArray.__mro__[:-1]
+)
class PyArrowArray(Plugin):
def __init__(self, **kwargs):
"""
@@ -51,8 +54,16 @@ def encode(self, obj, *args, **kwargs):
"""
buffers = []
obj_data = pickle.dumps(obj, protocol=5, buffer_callback=buffers.append)
- obj_buffers = list(map(lambda x: base64.b64encode(memoryview(x)).decode('ascii'), buffers))
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data.decode('latin1'), *obj_buffers))
+ obj_buffers = list(
+ map(lambda x: base64.b64encode(memoryview(x)).decode("ascii"), buffers)
+ )
+ return True, dict(
+ __wrapyfi__=(
+ str(self.__class__.__name__),
+ obj_data.decode("latin1"),
+ *obj_buffers,
+ )
+ )
def decode(self, obj_type, obj_full, *args, **kwargs):
"""
@@ -66,8 +77,10 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- bool: Always True, indicating that the decoding was successful
- pa.StructArray: The decoded PyArrow StructArray data
"""
- obj_data = obj_full[1].encode('latin1')
- obj_buffers = list(map(lambda x: base64.b64decode(x.encode('ascii')), obj_full[2:]))
+ obj_data = obj_full[1].encode("latin1")
+ obj_buffers = list(
+ map(lambda x: base64.b64decode(x.encode("ascii")), obj_full[2:])
+ )
obj_data = bytearray(obj_data)
for buf in obj_buffers:
obj_data += buf
diff --git a/wrapyfi/plugins/pytorch_tensor.py b/wrapyfi/plugins/pytorch_tensor.py
index 75f8e18..533ca10 100644
--- a/wrapyfi/plugins/pytorch_tensor.py
+++ b/wrapyfi/plugins/pytorch_tensor.py
@@ -24,6 +24,7 @@
try:
import torch
+
HAVE_TORCH = True
except ImportError:
HAVE_TORCH = False
@@ -40,7 +41,7 @@ def __init__(self, load_torch_device=None, map_torch_devices=None, **kwargs):
"""
self.map_torch_devices = map_torch_devices or {}
if load_torch_device is not None:
- self.map_torch_devices['default'] = load_torch_device
+ self.map_torch_devices["default"] = load_torch_device
def encode(self, obj, *args, **kwargs):
"""
@@ -56,9 +57,11 @@ def encode(self, obj, *args, **kwargs):
"""
with io.BytesIO() as memfile:
torch.save(obj, memfile)
- obj_data = base64.b64encode(memfile.getvalue()).decode('ascii')
+ obj_data = base64.b64encode(memfile.getvalue()).decode("ascii")
obj_device = str(obj.device)
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device))
+ return True, dict(
+ __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_device)
+ )
def decode(self, obj_type, obj_full, *args, **kwargs):
"""
@@ -72,10 +75,11 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- bool: Always True, indicating that the decoding was successful
- torch.Tensor: The decoded PyTorch tensor data
"""
- with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile:
- obj_device = self.map_torch_devices.get(obj_full[2], self.map_torch_devices.get('default', None))
+ with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile:
+ obj_device = self.map_torch_devices.get(
+ obj_full[2], self.map_torch_devices.get("default", None)
+ )
if obj_device is not None:
return True, torch.load(memfile, map_location=obj_device)
else:
return True, torch.load(memfile)
-
diff --git a/wrapyfi/plugins/tensorflow_tensor.py b/wrapyfi/plugins/tensorflow_tensor.py
index af2ae70..52bbba7 100644
--- a/wrapyfi/plugins/tensorflow_tensor.py
+++ b/wrapyfi/plugins/tensorflow_tensor.py
@@ -26,12 +26,15 @@
try:
import tensorflow
+
HAVE_TENSORFLOW = True
except ImportError:
HAVE_TENSORFLOW = False
-@PluginRegistrar.register(types=None if not HAVE_TENSORFLOW else tensorflow.Tensor.__mro__[:-1])
+@PluginRegistrar.register(
+ types=None if not HAVE_TENSORFLOW else tensorflow.Tensor.__mro__[:-1]
+)
class TensorflowTensor(Plugin):
def __init__(self, **kwargs):
"""
@@ -53,7 +56,7 @@ def encode(self, obj, *args, **kwargs):
"""
with io.BytesIO() as memfile:
np.save(memfile, obj.numpy())
- obj_data = base64.b64encode(memfile.getvalue()).decode('ascii')
+ obj_data = base64.b64encode(memfile.getvalue()).decode("ascii")
return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data))
def decode(self, obj_type, obj_full, *args, **kwargs):
@@ -68,5 +71,5 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- bool: Always True, indicating that the decoding was successful
- tensorflow.Tensor: The decoded TensorFlow tensor data
"""
- with io.BytesIO(base64.b64decode(obj_full[1].encode('ascii'))) as memfile:
+ with io.BytesIO(base64.b64decode(obj_full[1].encode("ascii"))) as memfile:
return True, tensorflow.convert_to_tensor(np.load(memfile))
diff --git a/wrapyfi/plugins/xarray_data.py b/wrapyfi/plugins/xarray_data.py
index f06df4d..24986b7 100644
--- a/wrapyfi/plugins/xarray_data.py
+++ b/wrapyfi/plugins/xarray_data.py
@@ -28,12 +28,17 @@
try:
import xarray as xr
+
HAVE_XARRAY = True
except ImportError:
HAVE_XARRAY = False
-@PluginRegistrar.register(types=None if not HAVE_XARRAY else xr.DataArray.__mro__[:-1] + xr.Dataset.__mro__[:-1])
+@PluginRegistrar.register(
+ types=(
+ None if not HAVE_XARRAY else xr.DataArray.__mro__[:-1] + xr.Dataset.__mro__[:-1]
+ )
+)
class XArrayData(Plugin):
def __init__(self, **kwargs):
"""
@@ -53,11 +58,11 @@ def encode(self, obj, *args, **kwargs):
- dict: A dictionary containing:
- '__wrapyfi__': A tuple containing the class name, encoded data string, data type, and object name
"""
- obj_type = 'Dataset'
+ obj_type = "Dataset"
obj_name = None
if isinstance(obj, xr.DataArray):
obj_dict = obj.to_dataset().to_dict()
- obj_type = 'DataArray'
+ obj_type = "DataArray"
obj_name = obj.name
elif isinstance(obj, xr.Dataset):
obj_dict = obj.to_dict()
@@ -68,17 +73,19 @@ def traverse_and_convert(obj):
elif isinstance(obj, dict):
return {key: traverse_and_convert(value) for key, value in obj.items()}
elif isinstance(obj, np.datetime64):
- return {'data': str(obj), 'dtype': str(obj.dtype)}
+ return {"data": str(obj), "dtype": str(obj.dtype)}
elif isinstance(obj, datetime):
- return {'data': obj.isoformat(), 'dtype': 'datetime'}
+ return {"data": obj.isoformat(), "dtype": "datetime"}
else:
return obj
converted_obj_dict = traverse_and_convert(obj_dict)
obj_json = json.dumps(converted_obj_dict)
- obj_data = base64.b64encode(obj_json.encode('ascii')).decode('ascii')
+ obj_data = base64.b64encode(obj_json.encode("ascii")).decode("ascii")
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type, obj_name))
+ return True, dict(
+ __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type, obj_name)
+ )
def decode(self, obj_type, obj_full, *args, **kwargs):
"""
@@ -92,7 +99,7 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
- bool: Always True, indicating that the decoding was successful
- Union[xr.DataArray, xr.Dataset]: The decoded XArray Data
"""
- obj_data = base64.b64decode(obj_full[1].encode('ascii')).decode('ascii')
+ obj_data = base64.b64decode(obj_full[1].encode("ascii")).decode("ascii")
xarray_type = obj_full[2]
xarray_name = obj_full[3]
@@ -100,21 +107,23 @@ def traverse_and_reconvert(obj):
if isinstance(obj, list):
return [traverse_and_reconvert(item) for item in obj]
elif isinstance(obj, dict):
- if 'data' in obj and 'dtype' in obj:
- if obj['dtype'] == 'datetime':
- return datetime.fromisoformat(obj['data'])
+ if "data" in obj and "dtype" in obj:
+ if obj["dtype"] == "datetime":
+ return datetime.fromisoformat(obj["data"])
else:
- return np.datetime64(obj['data'], obj['dtype'])
- return {key: traverse_and_reconvert(value) for key, value in obj.items()}
+ return np.datetime64(obj["data"], obj["dtype"])
+ return {
+ key: traverse_and_reconvert(value) for key, value in obj.items()
+ }
else:
return obj
obj_dict = json.loads(obj_data)
reconverted_obj_dict = traverse_and_reconvert(obj_dict)
- if xarray_type == 'DataArray':
+ if xarray_type == "DataArray":
dataset = xr.Dataset.from_dict(reconverted_obj_dict)
data_array = dataset[xarray_name]
return True, data_array
else:
- return True, xr.Dataset.from_dict(reconverted_obj_dict)
\ No newline at end of file
+ return True, xr.Dataset.from_dict(reconverted_obj_dict)
diff --git a/wrapyfi/plugins/zarr_array.py b/wrapyfi/plugins/zarr_array.py
index 47ded18..1fb16ff 100644
--- a/wrapyfi/plugins/zarr_array.py
+++ b/wrapyfi/plugins/zarr_array.py
@@ -26,12 +26,15 @@
try:
import zipfile
import zarr
+
HAVE_ZARR = True
except ImportError:
HAVE_ZARR = False
-@PluginRegistrar.register(types=None if not HAVE_ZARR else zarr.Array.__mro__[:-1] + zarr.Group.__mro__[:-1])
+@PluginRegistrar.register(
+ types=None if not HAVE_ZARR else zarr.Array.__mro__[:-1] + zarr.Group.__mro__[:-1]
+)
class ZarrData(Plugin):
def __init__(self, **kwargs):
"""
@@ -51,27 +54,32 @@ def encode(self, obj, *args, **kwargs):
- dict: A dictionary containing:
- '__wrapyfi__': A tuple containing the class name, encoded data string, data type, and object name.
"""
- obj_type = 'Group' if isinstance(obj, zarr.Group) else 'Array'
+ obj_type = "Group" if isinstance(obj, zarr.Group) else "Array"
obj_name = obj.name if isinstance(obj, zarr.Array) else None
with tempfile.TemporaryDirectory() as tmpdirname:
- store_path = os.path.join(tmpdirname, 'zarr_store')
+ store_path = os.path.join(tmpdirname, "zarr_store")
- if obj_type == 'Array':
+ if obj_type == "Array":
zarr.save_array(store_path, obj)
else:
zarr.save_group(store_path, obj)
with io.BytesIO() as binary_stream:
- with zipfile.ZipFile(binary_stream, 'w') as zipf:
+ with zipfile.ZipFile(binary_stream, "w") as zipf:
for foldername, subfolders, filenames in os.walk(store_path):
for filename in filenames:
- zipf.write(os.path.join(foldername, filename),
- arcname=os.path.relpath(os.path.join(foldername, filename),
- store_path))
- obj_data = base64.b64encode(binary_stream.getvalue()).decode('ascii')
-
- return True, dict(__wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type, obj_name))
+ zipf.write(
+ os.path.join(foldername, filename),
+ arcname=os.path.relpath(
+ os.path.join(foldername, filename), store_path
+ ),
+ )
+ obj_data = base64.b64encode(binary_stream.getvalue()).decode("ascii")
+
+ return True, dict(
+ __wrapyfi__=(str(self.__class__.__name__), obj_data, obj_type, obj_name)
+ )
def decode(self, obj_type, obj_full, *args, **kwargs):
"""
@@ -87,17 +95,19 @@ def decode(self, obj_type, obj_full, *args, **kwargs):
"""
obj_data = obj_full[1]
zarr_type = obj_full[2]
- zarr_name = obj_full[3] # zarr_name is used only if zarr_type is 'Array'. Currently not used.
+ zarr_name = obj_full[
+ 3
+ ] # zarr_name is used only if zarr_type is 'Array'. Currently not used.
- with io.BytesIO(base64.b64decode(obj_data.encode('ascii'))) as binary_stream:
+ with io.BytesIO(base64.b64decode(obj_data.encode("ascii"))) as binary_stream:
with tempfile.TemporaryDirectory() as tmpdirname:
- store_path = os.path.join(tmpdirname, 'zarr_store')
- with zipfile.ZipFile(binary_stream, 'r') as zipf:
+ store_path = os.path.join(tmpdirname, "zarr_store")
+ with zipfile.ZipFile(binary_stream, "r") as zipf:
zipf.extractall(path=store_path)
- if zarr_type == 'Array':
- array = zarr.open_array(store_path, mode='r')
+ if zarr_type == "Array":
+ array = zarr.open_array(store_path, mode="r")
return True, array
else:
- group = zarr.open_group(store_path, mode='r')
+ group = zarr.open_group(store_path, mode="r")
return True, group
diff --git a/wrapyfi/publishers/__init__.py b/wrapyfi/publishers/__init__.py
index 6880b6e..947d5b0 100755
--- a/wrapyfi/publishers/__init__.py
+++ b/wrapyfi/publishers/__init__.py
@@ -6,11 +6,22 @@
@Publishers.register("MMO", "fallback")
class FallbackPublisher(Publisher):
- def __init__(self, name: str, out_topic: str, carrier: str = "",
- should_wait: bool = True, missing_middleware_object: str = "", **kwargs):
- logging.warning(f"Fallback publisher employed due to missing middleware or object type: "
- f"{missing_middleware_object}")
- Publisher.__init__(self, name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "",
+ should_wait: bool = True,
+ missing_middleware_object: str = "",
+ **kwargs,
+ ):
+ logging.warning(
+ f"Fallback publisher employed due to missing middleware or object type: "
+ f"{missing_middleware_object}"
+ )
+ Publisher.__init__(
+ self, name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
self.missing_middleware_object = missing_middleware_object
def establish(self, repeats: int = -1, **kwargs):
@@ -20,4 +31,4 @@ def publish(self, obj):
return obj
def close(self):
- return None
\ No newline at end of file
+ return None
diff --git a/wrapyfi/publishers/ros.py b/wrapyfi/publishers/ros.py
index f4570d5..ec6537c 100755
--- a/wrapyfi/publishers/ros.py
+++ b/wrapyfi/publishers/ros.py
@@ -25,8 +25,16 @@
class ROSPublisher(Publisher):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True,
- queue_size: int = QUEUE_SIZE, ros_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ ros_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the publisher.
@@ -39,14 +47,20 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait:
:param kwargs: dict: Additional kwargs for the publisher
"""
if carrier or carrier != "tcp":
- logging.warning("[ROS] ROS does not support other carriers than TCP for PUB/SUB pattern. Using TCP.")
+ logging.warning(
+ "[ROS] ROS does not support other carriers than TCP for PUB/SUB pattern. Using TCP."
+ )
carrier = "tcp"
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ super().__init__(
+ name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
ROSMiddleware.activate(**ros_kwargs or {})
self.queue_size = queue_size
- def await_connection(self, publisher, out_topic: Optional[str] = None, repeats: Optional[int] = None):
+ def await_connection(
+ self, publisher, out_topic: Optional[str] = None, repeats: Optional[int] = None
+ ):
"""
Wait for at least one subscriber to connect to the publisher.
@@ -88,8 +102,16 @@ def __del__(self):
@Publishers.register("NativeObject", "ros")
class ROSNativeObjectPublisher(ROSPublisher):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True,
- queue_size: int = QUEUE_SIZE, serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObject publisher using the ROS String message assuming a combination of python native objects.
@@ -101,7 +123,14 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait:
:param queue_size: int: Queue size for the publisher. Default is 5
:param serializer_kwargs: dict: Additional kwargs for the serializer
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ queue_size=queue_size,
+ **kwargs,
+ )
self._plugin_encoder = JsonEncoder
self._plugin_kwargs = kwargs
self._serializer_kwargs = serializer_kwargs or {}
@@ -118,7 +147,9 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
:param repeats: int: Number of repeats to await connection. None for infinite. Default is None
:return: bool: True if connection established, False otherwise
"""
- self._publisher = rospy.Publisher(self.out_topic, std_msgs.msg.String, queue_size=self.queue_size)
+ self._publisher = rospy.Publisher(
+ self.out_topic, std_msgs.msg.String, queue_size=self.queue_size
+ )
established = self.await_connection(self._publisher, repeats=repeats)
return self.check_establishment(established)
@@ -134,16 +165,32 @@ def publish(self, obj):
return
else:
time.sleep(0.2)
- obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ obj_str = json.dumps(
+ obj,
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
self._publisher.publish(obj_str)
@Publishers.register("Image", "ros")
class ROSImagePublisher(ROSPublisher):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE,
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ **kwargs,
+ ):
"""
The ImagePublisher using the ROS Image message assuming a numpy array as input.
@@ -158,7 +205,14 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait
:param fp: bool: True if the image is floating point, False if it is integer. Default is False
:param jpg: bool: True if the image should be compressed as JPG. Default is False
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ queue_size=queue_size,
+ **kwargs,
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -166,13 +220,13 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait
self.jpg = jpg
if self.fp:
- self._encoding = '32FC3' if self.rgb else '32FC1'
+ self._encoding = "32FC3" if self.rgb else "32FC1"
self._type = np.float32
else:
- self._encoding = 'bgr8' if self.rgb else 'mono8'
+ self._encoding = "bgr8" if self.rgb else "mono8"
self._type = np.uint8
if self.jpg:
- self._encoding = 'jpeg'
+ self._encoding = "jpeg"
self._type = np.uint8
self._publisher = None
@@ -188,9 +242,15 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
:return: bool: True if connection established, False otherwise
"""
if self.jpg:
- self._publisher = rospy.Publisher(self.out_topic, sensor_msgs.msg.CompressedImage, queue_size=self.queue_size)
+ self._publisher = rospy.Publisher(
+ self.out_topic,
+ sensor_msgs.msg.CompressedImage,
+ queue_size=self.queue_size,
+ )
else:
- self._publisher = rospy.Publisher(self.out_topic, sensor_msgs.msg.Image, queue_size=self.queue_size)
+ self._publisher = rospy.Publisher(
+ self.out_topic, sensor_msgs.msg.Image, queue_size=self.queue_size
+ )
established = self.await_connection(self._publisher)
return self.check_establishment(established)
@@ -210,23 +270,31 @@ def publish(self, img: np.ndarray):
else:
time.sleep(0.2)
- if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \
- not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)):
+ if (
+ 0 < self.width != img.shape[1]
+ or 0 < self.height != img.shape[0]
+ or not (
+ (img.ndim == 2 and not self.rgb)
+ or (img.ndim == 3 and self.rgb and img.shape[2] == 3)
+ )
+ ):
raise ValueError("Incorrect image shape for publisher")
- img = np.require(img, dtype=self._type, requirements='C')
+ img = np.require(img, dtype=self._type, requirements="C")
if self.jpg:
img_msg = sensor_msgs.msg.CompressedImage()
img_msg.header.stamp = rospy.Time.now()
img_msg.format = "jpeg"
- img_msg.data = np.array(cv2.imencode('.jpg', img)[1]).tobytes()
+ img_msg.data = np.array(cv2.imencode(".jpg", img)[1]).tobytes()
else:
img_msg = sensor_msgs.msg.Image()
img_msg.header.stamp = rospy.Time.now()
img_msg.height = img.shape[0]
img_msg.width = img.shape[1]
img_msg.encoding = self._encoding
- img_msg.is_bigendian = img.dtype.byteorder == '>' or (img.dtype.byteorder == '=' and sys.byteorder == 'big')
+ img_msg.is_bigendian = img.dtype.byteorder == ">" or (
+ img.dtype.byteorder == "=" and sys.byteorder == "big"
+ )
img_msg.step = img.strides[0]
img_msg.data = img.tobytes()
self._publisher.publish(img_msg)
@@ -235,8 +303,18 @@ def publish(self, img: np.ndarray):
@Publishers.register("AudioChunk", "ros")
class ROSAudioChunkPublisher(ROSPublisher):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True, queue_size: int = QUEUE_SIZE,
- channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ **kwargs,
+ ):
"""
The AudioChunkPublisher using the ROS Audio message assuming a numpy array as input.
@@ -249,7 +327,14 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait:
:param rate: int: Sampling rate. Default is 44100
:param chunk: int: Chunk size. Default is -1 meaning that the chunk size is not fixed
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ queue_size=queue_size,
+ **kwargs,
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -270,12 +355,18 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
from wrapyfi_ros_interfaces.msg import ROSAudioMessage
except ImportError:
import wrapyfi
- logging.error("[ROS] Could not import ROSAudioMessage. "
- "Make sure the ROS messages in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS] Could not import ROSAudioMessage. "
+ "Make sure the ROS messages in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros_interfaces_lnk.html"
+ )
sys.exit(1)
- self._publisher = rospy.Publisher(self.out_topic, ROSAudioMessage, queue_size=self.queue_size)
+ self._publisher = rospy.Publisher(
+ self.out_topic, ROSAudioMessage, queue_size=self.queue_size
+ )
self._sound_msg = ROSAudioMessage()
established = self.await_connection(self._publisher)
return self.check_establishment(established)
@@ -303,15 +394,17 @@ def publish(self, aud: Tuple[np.ndarray, int]):
self.channels = channels if self.channels == -1 else self.channels
if 0 < self.chunk != chunk or 0 < self.channels != channels:
raise ValueError("Incorrect audio shape for publisher")
- aud = np.require(aud, dtype=np.float32, requirements='C')
+ aud = np.require(aud, dtype=np.float32, requirements="C")
aud_msg = self._sound_msg
aud_msg.header.stamp = rospy.Time.now()
aud_msg.chunk_size = chunk
aud_msg.channels = channels
aud_msg.sample_rate = rate
- aud_msg.is_bigendian = aud.dtype.byteorder == '>' or (aud.dtype.byteorder == '=' and sys.byteorder == 'big')
- aud_msg.encoding = 'S16BE' if aud_msg.is_bigendian else 'S16LE'
+ aud_msg.is_bigendian = aud.dtype.byteorder == ">" or (
+ aud.dtype.byteorder == "=" and sys.byteorder == "big"
+ )
+ aud_msg.encoding = "S16BE" if aud_msg.is_bigendian else "S16LE"
aud_msg.step = aud.strides[0]
aud_msg.data = aud.tobytes() # (aud * 32767.0).tobytes()
self._publisher.publish(aud_msg)
@@ -326,7 +419,15 @@ class ROSPropertiesPublisher(ROSPublisher):
but care should be taken when using dictionaries, since they are analogous with node namespaces:
http://wiki.ros.org/rospy/Overview/Parameter%20Server
"""
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp", persistent: bool = True, **kwargs):
+
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ persistent: bool = True,
+ **kwargs,
+ ):
"""
The PropertiesPublisher using the ROS parameter server.
@@ -371,8 +472,15 @@ def __del__(self):
@Publishers.register("ROSMessage", "ros")
class ROSMessagePublisher(ROSPublisher):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
- should_wait: bool = True, queue_size: int = QUEUE_SIZE, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ **kwargs,
+ ):
"""
The ROSMessagePublisher using the ROS message type inferred from the message type. Supports standard ROS msgs.
@@ -382,7 +490,14 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
:param should_wait: bool: Whether to wait for at least one listener before unblocking the script. Default is True
:param queue_size: int: Queue size for the publisher. Default is 5
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ queue_size=queue_size,
+ **kwargs,
+ )
self._publisher = None
@@ -402,7 +517,9 @@ def establish(self, repeats: Optional[int] = None, obj=None, **kwargs):
obj_type = obj._type.split("/")
import_msg = importlib.import_module(f"{obj_type[0]}.msg")
msg_type = getattr(import_msg, obj_type[1])
- self._publisher = rospy.Publisher(self.out_topic, msg_type, queue_size=self.queue_size)
+ self._publisher = rospy.Publisher(
+ self.out_topic, msg_type, queue_size=self.queue_size
+ )
established = self.await_connection(self._publisher, repeats=repeats)
return self.check_establishment(established)
diff --git a/wrapyfi/publishers/ros2.py b/wrapyfi/publishers/ros2.py
index 40c6ae9..b4441b8 100755
--- a/wrapyfi/publishers/ros2.py
+++ b/wrapyfi/publishers/ros2.py
@@ -24,8 +24,15 @@
class ROS2Publisher(Publisher, Node):
- def __init__(self, name: str, out_topic: str, should_wait: bool = True,
- queue_size: int = QUEUE_SIZE, ros2_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ ros2_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the publisher.
@@ -38,16 +45,22 @@ def __init__(self, name: str, out_topic: str, should_wait: bool = True,
"""
carrier = "tcp"
if "carrier" in kwargs and kwargs["carrier"] not in ["", None]:
- logging.warning("[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP.")
+ logging.warning(
+ "[ROS 2] ROS 2 currently does not support explicit carrier setting for PUB/SUB pattern. Using TCP."
+ )
if "carrier" in kwargs:
del kwargs["carrier"]
ROS2Middleware.activate(**ros2_kwargs or {})
- Publisher.__init__(self, name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ Publisher.__init__(
+ self, name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
Node.__init__(self, name + str(hex(id(self))))
self.queue_size = queue_size
- def await_connection(self, publisher, out_topic: Optional[str] = None, repeats: Optional[int] = None):
+ def await_connection(
+ self, publisher, out_topic: Optional[str] = None, repeats: Optional[int] = None
+ ):
"""
Wait for at least one subscriber to connect to the publisher.
@@ -80,7 +93,7 @@ def close(self):
"""
if hasattr(self, "_publisher") and self._publisher:
if self._publisher is not None:
- self.destroy_node()
+ self.destroy_node()
def __del__(self):
self.close()
@@ -89,8 +102,15 @@ def __del__(self):
@Publishers.register("NativeObject", "ros2")
class ROS2NativeObjectPublisher(ROS2Publisher):
- def __init__(self, name, out_topic: str, should_wait: bool = True,
- queue_size: int = QUEUE_SIZE, serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name,
+ out_topic: str,
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObject publisher using the ROS 2 String message assuming a combination of python native objects
and numpy arrays as input. Serializes the data (including plugins) using the encoder and sends it as a string.
@@ -101,7 +121,9 @@ def __init__(self, name, out_topic: str, should_wait: bool = True,
:param queue_size: int: Queue size for the publisher. Default is 5
:param serializer_kwargs: dict: Additional kwargs for the serializer
"""
- super().__init__(name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs
+ )
self._plugin_encoder = JsonEncoder
self._plugin_kwargs = kwargs
self._serializer_kwargs = serializer_kwargs or {}
@@ -118,7 +140,9 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
:param repeats: int: Number of repeats to await connection. None for infinite. Default is None
:return: bool: True if connection established, False otherwise
"""
- self._publisher = self.create_publisher(std_msgs.msg.String, self.out_topic, qos_profile=self.queue_size)
+ self._publisher = self.create_publisher(
+ std_msgs.msg.String, self.out_topic, qos_profile=self.queue_size
+ )
established = self.await_connection(self._publisher, repeats=repeats)
return self.check_establishment(established)
@@ -134,8 +158,12 @@ def publish(self, obj):
return
else:
time.sleep(0.2)
- obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ obj_str = json.dumps(
+ obj,
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
obj_str_msg = std_msgs.msg.String()
obj_str_msg.data = obj_str
self._publisher.publish(obj_str_msg)
@@ -144,8 +172,19 @@ def publish(self, obj):
@Publishers.register("Image", "ros2")
class ROS2ImagePublisher(ROS2Publisher):
- def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE,
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ **kwargs,
+ ):
"""
The ImagePublisher using the ROS 2 Image message assuming a numpy array as input.
@@ -159,7 +198,9 @@ def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_si
:param fp: bool: True if the image is floating point, False if it is integer. Default is False
:param jpg: bool: True if the image should be compressed as JPG. Default is False
"""
- super().__init__(name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -167,13 +208,13 @@ def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_si
self.jpg = jpg
if self.fp:
- self._encoding = '32FC3' if self.rgb else '32FC1'
+ self._encoding = "32FC3" if self.rgb else "32FC1"
self._type = np.float32
else:
- self._encoding = 'bgr8' if self.rgb else 'mono8'
+ self._encoding = "bgr8" if self.rgb else "mono8"
self._type = np.uint8
if self.jpg:
- self._encoding = 'jpeg'
+ self._encoding = "jpeg"
self._type = np.uint8
self._publisher = None
@@ -189,9 +230,15 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
:return: bool: True if connection established, False otherwise
"""
if self.jpg:
- self._publisher = self.create_publisher(sensor_msgs.msg.CompressedImage, self.out_topic, qos_profile=self.queue_size)
+ self._publisher = self.create_publisher(
+ sensor_msgs.msg.CompressedImage,
+ self.out_topic,
+ qos_profile=self.queue_size,
+ )
else:
- self._publisher = self.create_publisher(sensor_msgs.msg.Image, self.out_topic, qos_profile=self.queue_size)
+ self._publisher = self.create_publisher(
+ sensor_msgs.msg.Image, self.out_topic, qos_profile=self.queue_size
+ )
established = self.await_connection(self._publisher)
return self.check_establishment(established)
@@ -211,23 +258,31 @@ def publish(self, img: np.ndarray):
else:
time.sleep(0.2)
- if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \
- not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)):
+ if (
+ 0 < self.width != img.shape[1]
+ or 0 < self.height != img.shape[0]
+ or not (
+ (img.ndim == 2 and not self.rgb)
+ or (img.ndim == 3 and self.rgb and img.shape[2] == 3)
+ )
+ ):
raise ValueError("Incorrect image shape for publisher")
- img = np.require(img, dtype=self._type, requirements='C')
+ img = np.require(img, dtype=self._type, requirements="C")
if self.jpg:
img_msg = sensor_msgs.msg.CompressedImage()
img_msg.header.stamp = rclpy.clock.Clock().now().to_msg()
img_msg.format = "jpeg"
- img_msg.data = np.array(cv2.imencode('.jpg', img)[1]).tobytes()
+ img_msg.data = np.array(cv2.imencode(".jpg", img)[1]).tobytes()
else:
img_msg = sensor_msgs.msg.Image()
img_msg.header.stamp = self.get_clock().now().to_msg()
img_msg.height = img.shape[0]
img_msg.width = img.shape[1]
img_msg.encoding = self._encoding
- img_msg.is_bigendian = img.dtype.byteorder == '>' or (img.dtype.byteorder == '=' and sys.byteorder == 'big')
+ img_msg.is_bigendian = img.dtype.byteorder == ">" or (
+ img.dtype.byteorder == "=" and sys.byteorder == "big"
+ )
img_msg.step = img.strides[0]
img_msg.data = img.tobytes()
self._publisher.publish(img_msg)
@@ -236,8 +291,17 @@ def publish(self, img: np.ndarray):
@Publishers.register("AudioChunk", "ros2")
class ROS2AudioChunkPublisher(ROS2Publisher):
- def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE,
- channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ **kwargs,
+ ):
"""
The AudioChunkPublisher using the ROS 2 Audio message assuming a numpy array as input.
@@ -249,7 +313,9 @@ def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_si
:param rate: int: Sampling rate. Default is 44100
:param chunk: int: Chunk size. Default is -1 meaning that the chunk size is not fixed
"""
- super().__init__(name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -270,12 +336,18 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
from wrapyfi_ros2_interfaces.msg import ROS2AudioMessage
except ImportError:
import wrapyfi
- logging.error("[ROS 2] Could not import ROS2AudioMessage. "
- "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros2_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS 2] Could not import ROS2AudioMessage. "
+ "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros2_interfaces_lnk.html"
+ )
sys.exit(1)
- self._publisher = self.create_publisher(ROS2AudioMessage, self.out_topic, qos_profile=self.queue_size)
+ self._publisher = self.create_publisher(
+ ROS2AudioMessage, self.out_topic, qos_profile=self.queue_size
+ )
self._sound_msg = ROS2AudioMessage()
established = self.await_connection(self._publisher)
return self.check_establishment(established)
@@ -303,15 +375,17 @@ def publish(self, aud: Tuple[np.ndarray, int]):
self.channels = channels if self.channels == -1 else self.channels
if 0 < self.chunk != chunk or 0 < self.channels != channels:
raise ValueError("Incorrect audio shape for publisher")
- aud = np.require(aud, dtype=np.float32, requirements='C')
+ aud = np.require(aud, dtype=np.float32, requirements="C")
aud_msg = self._sound_msg
aud_msg.header.stamp = self.get_clock().now().to_msg()
aud_msg.chunk_size = chunk
aud_msg.channels = channels
aud_msg.sample_rate = rate
- aud_msg.is_bigendian = aud.dtype.byteorder == '>' or (aud.dtype.byteorder == '=' and sys.byteorder == 'big')
- aud_msg.encoding = 'S16BE' if aud_msg.is_bigendian else 'S16LE'
+ aud_msg.is_bigendian = aud.dtype.byteorder == ">" or (
+ aud.dtype.byteorder == "=" and sys.byteorder == "big"
+ )
+ aud_msg.encoding = "S16BE" if aud_msg.is_bigendian else "S16LE"
aud_msg.step = aud.strides[0]
aud_msg.data = aud.tobytes() # (aud * 32767.0).tobytes()
self._publisher.publish(aud_msg)
@@ -327,7 +401,14 @@ def __init__(self, name, out_topic, **kwargs):
@Publishers.register("ROS2Message", "ros2")
class ROS2MessagePublisher(ROS2Publisher):
- def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_size: int = QUEUE_SIZE, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ should_wait: bool = True,
+ queue_size: int = QUEUE_SIZE,
+ **kwargs,
+ ):
"""
The ROS2MessagePublisher using the ROS 2 message type determined dynamically.
@@ -336,7 +417,9 @@ def __init__(self, name: str, out_topic: str, should_wait: bool = True, queue_si
:param should_wait: bool: Whether to wait for at least one listener before unblocking the script. Default is True
:param queue_size: int: Queue size for the publisher. Default is 5
"""
- super().__init__(name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs)
+ super().__init__(
+ name, out_topic, should_wait=should_wait, queue_size=queue_size, **kwargs
+ )
self._publisher = None
self._msg_type = None
@@ -362,7 +445,9 @@ def establish(self, msg, repeats: Optional[int] = None, **kwargs):
"""
self._msg_type = self.get_message_type(msg)
- self._publisher = self.create_publisher(self._msg_type, self.out_topic, qos_profile=self.queue_size)
+ self._publisher = self.create_publisher(
+ self._msg_type, self.out_topic, qos_profile=self.queue_size
+ )
established = self.await_connection(self._publisher)
return self.check_establishment(established)
diff --git a/wrapyfi/publishers/yarp.py b/wrapyfi/publishers/yarp.py
index 6a8d893..1ac954d 100755
--- a/wrapyfi/publishers/yarp.py
+++ b/wrapyfi/publishers/yarp.py
@@ -19,8 +19,17 @@
class YarpPublisher(Publisher):
- def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True,
- persistent: bool = True, out_topic_connect: Optional[str] = None, yarp_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ should_wait: bool = True,
+ persistent: bool = True,
+ out_topic_connect: Optional[str] = None,
+ yarp_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the publisher.
@@ -34,16 +43,22 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc
:param yarp_kwargs: dict: Additional kwargs for the Yarp middleware
:param kwargs: dict: Additional kwargs for the publisher
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ super().__init__(
+ name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
YarpMiddleware.activate(**yarp_kwargs or {})
self.style = yarp.ContactStyle()
self.style.persistent = persistent
self.style.carrier = self.carrier
- self.out_topic_connect = out_topic + ":out" if out_topic_connect is None else out_topic_connect
+ self.out_topic_connect = (
+ out_topic + ":out" if out_topic_connect is None else out_topic_connect
+ )
- def await_connection(self, port, out_topic: Optional[str] = None, repeats: Optional[int] = None):
+ def await_connection(
+ self, port, out_topic: Optional[str] = None, repeats: Optional[int] = None
+ ):
"""
Wait for at least one subscriber to connect to the publisher.
@@ -85,8 +100,17 @@ def __del__(self):
@Publishers.register("NativeObject", "yarp")
class YarpNativeObjectPublisher(YarpPublisher):
- def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True,
- persistent: bool = True, out_topic_connect: str = None, serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ should_wait: bool = True,
+ persistent: bool = True,
+ out_topic_connect: str = None,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObject publisher using the BufferedPortBottle string construct assuming a combination of python native objects
and numpy arrays as input. Serializes the data (including plugins) using the encoder and sends it as a string.
@@ -100,8 +124,15 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc
None appends ':out' to the out_topic. Default is None
:param serializer_kwargs: dict: Additional kwargs for the serializer
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, persistent=persistent,
- out_topic_connect=out_topic_connect, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ persistent=persistent,
+ out_topic_connect=out_topic_connect,
+ **kwargs,
+ )
self._plugin_encoder = JsonEncoder
self._plugin_kwargs = kwargs
self._serializer_kwargs = serializer_kwargs or {}
@@ -121,9 +152,13 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
self._port = yarp.BufferedPortBottle()
self._port.open(self.out_topic)
if self.style.persistent:
- self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.style)
+ self._netconnect = yarp.Network.connect(
+ self.out_topic, self.out_topic_connect, self.style
+ )
else:
- self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.carrier)
+ self._netconnect = yarp.Network.connect(
+ self.out_topic, self.out_topic_connect, self.carrier
+ )
established = self.await_connection(self._port, repeats=repeats)
return self.check_establishment(established)
@@ -139,8 +174,12 @@ def publish(self, obj):
return
else:
time.sleep(0.2)
- obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ obj_str = json.dumps(
+ obj,
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
obj_port = self._port.prepare()
obj_port.clear()
obj_port.addString(obj_str)
@@ -150,9 +189,21 @@ def publish(self, obj):
@Publishers.register("Image", "yarp")
class YarpImagePublisher(YarpPublisher):
- def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True,
- persistent: bool = True, out_topic_connect: Optional[str] = None, width: int = -1, height: int = -1,
- rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ should_wait: bool = True,
+ persistent: bool = True,
+ out_topic_connect: Optional[str] = None,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ **kwargs,
+ ):
"""
The Image publisher using the BufferedPortImage construct assuming a numpy array as input.
@@ -169,8 +220,15 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc
:param fp: bool: True if the image is floating point, False if it is integer. Default is False
:param jpg: bool: True if the image should be compressed as JPG. Default is False
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, persistent=persistent,
- out_topic_connect=out_topic_connect, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ persistent=persistent,
+ out_topic_connect=out_topic_connect,
+ **kwargs,
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -192,15 +250,27 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
if self.jpg:
self._port = yarp.BufferedPortBottle()
elif self.rgb:
- self._port = yarp.BufferedPortImageRgbFloat() if self.fp else yarp.BufferedPortImageRgb()
+ self._port = (
+ yarp.BufferedPortImageRgbFloat()
+ if self.fp
+ else yarp.BufferedPortImageRgb()
+ )
else:
- self._port = yarp.BufferedPortImageFloat() if self.fp else yarp.BufferedPortImageMono()
+ self._port = (
+ yarp.BufferedPortImageFloat()
+ if self.fp
+ else yarp.BufferedPortImageMono()
+ )
self._type = np.float32 if self.fp else np.uint8
self._port.open(self.out_topic)
if self.style.persistent:
- self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.style)
+ self._netconnect = yarp.Network.connect(
+ self.out_topic, self.out_topic_connect, self.style
+ )
else:
- self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.carrier)
+ self._netconnect = yarp.Network.connect(
+ self.out_topic, self.out_topic_connect, self.carrier
+ )
established = self.await_connection(self._port, repeats=repeats)
return self.check_establishment(established)
@@ -220,16 +290,22 @@ def publish(self, img: np.ndarray):
else:
time.sleep(0.2)
- if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \
- not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)):
+ if (
+ 0 < self.width != img.shape[1]
+ or 0 < self.height != img.shape[0]
+ or not (
+ (img.ndim == 2 and not self.rgb)
+ or (img.ndim == 3 and self.rgb and img.shape[2] == 3)
+ )
+ ):
raise ValueError("Incorrect image shape for publisher")
- img = np.require(img, dtype=self._type, requirements='C')
+ img = np.require(img, dtype=self._type, requirements="C")
if self.jpg:
- img_str = np.array(cv2.imencode('.jpg', img)[1]).tostring()
+ img_str = np.array(cv2.imencode(".jpg", img)[1]).tostring()
with io.BytesIO() as memfile:
np.save(memfile, img_str)
- img_str = base64.b64encode(memfile.getvalue()).decode('ascii')
+ img_str = base64.b64encode(memfile.getvalue()).decode("ascii")
img_port = self._port.prepare()
img_port.clear()
img_port.addString(img_str)
@@ -245,9 +321,19 @@ def publish(self, img: np.ndarray):
@Publishers.register("AudioChunk", "yarp")
class YarpAudioChunkPublisher(YarpPublisher):
- def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp", should_wait: bool = True,
- persistent: bool = True, out_topic_connect: Optional[str] = None,
- channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ should_wait: bool = True,
+ persistent: bool = True,
+ out_topic_connect: Optional[str] = None,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ **kwargs,
+ ):
"""
The AudioChunk publisher using the Sound construct assuming a numpy array as input.
@@ -262,8 +348,15 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc
:param rate: int: Sampling rate. Default is 44100
:param chunk: int: Chunk size. Default is -1 meaning that the chunk size is not fixed
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, out_topic_connect=out_topic_connect,
- persistent=persistent, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ out_topic_connect=out_topic_connect,
+ persistent=persistent,
+ **kwargs,
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -283,7 +376,9 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
# create a dummy sound object for transmitting the sound props. This could be cleaner but left for future impl.
self._port = yarp.Port()
self._port.open(self.out_topic)
- self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.carrier)
+ self._netconnect = yarp.Network.connect(
+ self.out_topic, self.out_topic_connect, self.carrier
+ )
self._sound_msg = yarp.Sound()
self._sound_msg.setFrequency(self.rate)
self._sound_msg.resize(self.chunk, self.channels)
@@ -316,10 +411,12 @@ def publish(self, aud: Tuple[np.ndarray, int]):
self.channels = channels if self.channels == -1 else self.channels
if 0 < self.chunk != chunk or 0 < self.channels != channels:
raise ValueError("Incorrect audio shape for publisher")
- aud = np.require(aud, dtype=np.float32, requirements='C')
+ aud = np.require(aud, dtype=np.float32, requirements="C")
for i in range(aud.size):
- self._sound_msg.set(int(aud.data[i] * 32767), i) # Convert float samples to 16-bit int
+ self._sound_msg.set(
+ int(aud.data[i] * 32767), i
+ ) # Convert float samples to 16-bit int
self._port.write(self._sound_msg)
diff --git a/wrapyfi/publishers/zeromq.py b/wrapyfi/publishers/zeromq.py
index 63109d6..d531b16 100644
--- a/wrapyfi/publishers/zeromq.py
+++ b/wrapyfi/publishers/zeromq.py
@@ -22,20 +22,38 @@
PARAM_SUB_PORT = int(os.environ.get("WRAPYFI_ZEROMQ_PARAM_SUB_PORT", 5656))
PARAM_REQREP_PORT = int(os.environ.get("WRAPYFI_ZEROMQ_PARAM_REQREP_PORT", 5659))
PARAM_POLL_INTERVAL = int(os.environ.get("WRAPYFI_ZEROMQ_PARAM_POLL_INTERVAL", 1))
-START_PROXY_BROKER = os.environ.get("WRAPYFI_ZEROMQ_START_PROXY_BROKER", True) != "False"
+START_PROXY_BROKER = (
+ os.environ.get("WRAPYFI_ZEROMQ_START_PROXY_BROKER", True) != "False"
+)
PROXY_BROKER_SPAWN = os.environ.get("WRAPYFI_ZEROMQ_PROXY_BROKER_SPAWN", "process")
-ZEROMQ_PUBSUB_MONITOR_TOPIC = os.environ.get("WRAPYFI_ZEROMQ_PUBSUB_MONITOR_TOPIC", "ZEROMQ/CONNECTIONS")
-ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN = os.environ.get("WRAPYFI_ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN", "process")
+ZEROMQ_PUBSUB_MONITOR_TOPIC = os.environ.get(
+ "WRAPYFI_ZEROMQ_PUBSUB_MONITOR_TOPIC", "ZEROMQ/CONNECTIONS"
+)
+ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN = os.environ.get(
+ "WRAPYFI_ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN", "process"
+)
WATCHDOG_POLL_REPEAT = None
class ZeroMQPublisher(Publisher):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True,
- socket_ip: str = SOCKET_IP, socket_pub_port: int = SOCKET_PUB_PORT, socket_sub_port: int = SOCKET_SUB_PORT,
- start_proxy_broker: bool = START_PROXY_BROKER, proxy_broker_spawn: str = PROXY_BROKER_SPAWN,
- pubsub_monitor_topic: str = ZEROMQ_PUBSUB_MONITOR_TOPIC,
- pubsub_monitor_listener_spawn: Optional[str] = ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN,
- zeromq_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ socket_ip: str = SOCKET_IP,
+ socket_pub_port: int = SOCKET_PUB_PORT,
+ socket_sub_port: int = SOCKET_SUB_PORT,
+ start_proxy_broker: bool = START_PROXY_BROKER,
+ proxy_broker_spawn: str = PROXY_BROKER_SPAWN,
+ pubsub_monitor_topic: str = ZEROMQ_PUBSUB_MONITOR_TOPIC,
+ pubsub_monitor_listener_spawn: Optional[
+ str
+ ] = ZEROMQ_PUBSUB_MONITOR_LISTENER_SPAWN,
+ zeromq_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the publisher and start the proxy broker if necessary.
@@ -54,24 +72,32 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait:
:param kwargs: Additional kwargs for the publisher
"""
if carrier or carrier != "tcp":
- logging.warning("[ZeroMQ] ZeroMQ does not support other carriers than TCP for PUB/SUB pattern. Using TCP.")
+ logging.warning(
+ "[ZeroMQ] ZeroMQ does not support other carriers than TCP for PUB/SUB pattern. Using TCP."
+ )
carrier = "tcp"
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ super().__init__(
+ name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
self.socket_pub_address = f"{carrier}://{socket_ip}:{socket_pub_port}"
self.socket_sub_address = f"{carrier}://{socket_ip}:{socket_sub_port}"
- ZeroMQMiddlewarePubSub.activate(socket_pub_address=self.socket_pub_address,
- socket_sub_address=self.socket_sub_address,
- start_proxy_broker=start_proxy_broker,
- proxy_broker_spawn=proxy_broker_spawn,
- pubsub_monitor_topic=pubsub_monitor_topic,
- pubsub_monitor_listener_spawn=pubsub_monitor_listener_spawn,
- **zeromq_kwargs or {})
+ ZeroMQMiddlewarePubSub.activate(
+ socket_pub_address=self.socket_pub_address,
+ socket_sub_address=self.socket_sub_address,
+ start_proxy_broker=start_proxy_broker,
+ proxy_broker_spawn=proxy_broker_spawn,
+ pubsub_monitor_topic=pubsub_monitor_topic,
+ pubsub_monitor_listener_spawn=pubsub_monitor_listener_spawn,
+ **zeromq_kwargs or {},
+ )
ZeroMQMiddlewarePubSub().shared_monitor_data.add_topic(self.out_topic)
- def await_connection(self, out_topic: Optional[str] = None, repeats: Optional[int] = None):
+ def await_connection(
+ self, out_topic: Optional[str] = None, repeats: Optional[int] = None
+ ):
"""
Wait for the connection to be established.
@@ -90,7 +116,9 @@ def await_connection(self, out_topic: Optional[str] = None, repeats: Optional[in
return True
while repeats > 0 or repeats <= -1:
repeats -= 1
- connected = ZeroMQMiddlewarePubSub().shared_monitor_data.is_connected(out_topic)
+ connected = ZeroMQMiddlewarePubSub().shared_monitor_data.is_connected(
+ out_topic
+ )
if connected:
break
time.sleep(0.02)
@@ -115,8 +143,15 @@ def __del__(self):
@Publishers.register("NativeObject", "zeromq")
class ZeroMQNativeObjectPublisher(ZeroMQPublisher):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True,
- serializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ serializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
The NativeObjectPublisher using the ZeroMQ message construct assuming a combination of python native objects
and numpy arrays as input. Serializes the data (including plugins) using the encoder and sends it as a string.
@@ -128,7 +163,9 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait:
:param serializer_kwargs: dict: Additional kwargs for the serializer
:param kwargs: dict: Additional kwargs for the publisher
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ super().__init__(
+ name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
self._socket = self._netconnect = None
self._plugin_encoder = JsonEncoder
@@ -148,9 +185,13 @@ def establish(self, repeats: Optional[int] = None, **kwargs):
self._socket = zmq.Context.instance().socket(zmq.PUB)
for socket_property in ZeroMQMiddlewarePubSub().zeromq_kwargs.items():
if isinstance(socket_property[1], str):
- self._socket.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1])
+ self._socket.setsockopt_string(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
else:
- self._socket.setsockopt(getattr(zmq, socket_property[0]), socket_property[1])
+ self._socket.setsockopt(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
self._socket.connect(self.socket_sub_address)
self._topic = self.out_topic.encode()
established = self.await_connection(repeats=repeats)
@@ -168,16 +209,31 @@ def publish(self, obj):
return
else:
time.sleep(0.2)
- obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ obj_str = json.dumps(
+ obj,
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
self._socket.send_multipart([self._topic, obj_str.encode()])
@Publishers.register("Image", "zeromq")
class ZeroMQImagePublisher(ZeroMQNativeObjectPublisher):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True,
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ **kwargs,
+ ):
"""
The ImagePublisher using the ZeroMQ message construct assuming a numpy array as input.
@@ -191,7 +247,9 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait:
:param fp: bool: True if the image is floating point, False if it is integer. Default is False
:param jpg: bool: True if the image should be compressed as JPG. Default is False
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs)
+ super().__init__(
+ name, out_topic, carrier=carrier, should_wait=should_wait, **kwargs
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -215,25 +273,44 @@ def publish(self, img: np.ndarray):
return
else:
time.sleep(0.2)
- if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \
- not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)):
+ if (
+ 0 < self.width != img.shape[1]
+ or 0 < self.height != img.shape[0]
+ or not (
+ (img.ndim == 2 and not self.rgb)
+ or (img.ndim == 3 and self.rgb and img.shape[2] == 3)
+ )
+ ):
raise ValueError("Incorrect image shape for publisher")
- if not img.flags['C_CONTIGUOUS']:
+ if not img.flags["C_CONTIGUOUS"]:
img = np.ascontiguousarray(img)
if self.jpg:
- img_str = np.array(cv2.imencode('.jpg', img)[1]).tostring()
+ img_str = np.array(cv2.imencode(".jpg", img)[1]).tostring()
else:
- img_str = json.dumps(img, cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs).encode()
- img_header = '{timestamp:' + str(time.time()) + '}'
+ img_str = json.dumps(
+ img,
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ ).encode()
+ img_header = "{timestamp:" + str(time.time()) + "}"
self._socket.send_multipart([self._topic, img_header.encode(), img_str])
@Publishers.register("AudioChunk", "zeromq")
class ZeroMQAudioChunkPublisher(ZeroMQNativeObjectPublisher):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait: bool = True,
- channels: int = 1, rate: int = 44100, chunk: int = -1, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ should_wait: bool = True,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ **kwargs,
+ ):
"""
The AudioChunkPublisher using the ZeroMQ message construct assuming a numpy array as input.
@@ -245,8 +322,18 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", should_wait:
:param rate: int: Sampling rate. Default is 44100
:param chunk: int: Chunk size. Default is -1 meaning that the chunk size is not fixed
"""
- super().__init__(name, out_topic, carrier=carrier, should_wait=should_wait,
- width=chunk, height=channels, rgb=False, fp=True, jpg=False, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ should_wait=should_wait,
+ width=chunk,
+ height=channels,
+ rgb=False,
+ fp=True,
+ jpg=False,
+ **kwargs,
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -274,11 +361,15 @@ def publish(self, aud: Tuple[np.ndarray, int]):
self.channels = channels if self.channels == -1 else self.channels
if 0 < self.chunk != chunk or 0 < self.channels != channels:
raise ValueError("Incorrect audio shape for publisher")
- aud = np.require(aud, dtype=np.float32, requirements='C')
-
- aud_str = json.dumps((chunk, channels, rate, aud), cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs).encode()
- aud_header = '{timestamp:' + str(time.time()) + '}'
+ aud = np.require(aud, dtype=np.float32, requirements="C")
+
+ aud_str = json.dumps(
+ (chunk, channels, rate, aud),
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ ).encode()
+ aud_header = "{timestamp:" + str(time.time()) + "}"
self._socket.send_multipart([self._topic, aud_header.encode(), aud_str])
diff --git a/wrapyfi/servers/__init__.py b/wrapyfi/servers/__init__.py
index ef2329e..9ef2390 100755
--- a/wrapyfi/servers/__init__.py
+++ b/wrapyfi/servers/__init__.py
@@ -6,9 +6,18 @@
@Servers.register("MMO", "fallback")
class FallbackServer(Server):
- def __init__(self, name: str, out_topic: str, carrier: str = "", missing_middleware_object: str = "", **kwargs):
- logging.warning(f"Fallback server employed due to missing middleware or object type: "
- f"{missing_middleware_object}")
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "",
+ missing_middleware_object: str = "",
+ **kwargs,
+ ):
+ logging.warning(
+ f"Fallback server employed due to missing middleware or object type: "
+ f"{missing_middleware_object}"
+ )
Server.__init__(self, name, out_topic, carrier=carrier, **kwargs)
self.missing_middleware_object = missing_middleware_object
diff --git a/wrapyfi/servers/ros.py b/wrapyfi/servers/ros.py
index c5c1fda..b22753b 100755
--- a/wrapyfi/servers/ros.py
+++ b/wrapyfi/servers/ros.py
@@ -13,13 +13,24 @@
import sensor_msgs.msg
from wrapyfi.connect.servers import Server, Servers
-from wrapyfi.middlewares.ros import ROSMiddleware, ROSNativeObjectService, ROSImageService
+from wrapyfi.middlewares.ros import (
+ ROSMiddleware,
+ ROSNativeObjectService,
+ ROSImageService,
+)
from wrapyfi.encoders import JsonEncoder, JsonDecodeHook
class ROSServer(Server):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp", ros_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ ros_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the server.
@@ -30,7 +41,9 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp", ros_kwargs:
:param kwargs: dict: Additional kwargs for the server
"""
if carrier or carrier != "tcp":
- logging.warning("[ROS] ROS does not support other carriers than TCP for REQ/REP pattern. Using TCP.")
+ logging.warning(
+ "[ROS] ROS does not support other carriers than TCP for REQ/REP pattern. Using TCP."
+ )
carrier = "tcp"
super().__init__(name, out_topic, carrier=carrier, **kwargs)
ROSMiddleware.activate(**ros_kwargs or {})
@@ -52,8 +65,15 @@ class ROSNativeObjectServer(ROSServer):
SEND_QUEUE = queue.Queue(maxsize=1)
RECEIVE_QUEUE = queue.Queue(maxsize=1)
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
- serializer_kwargs: Optional[dict] = None, deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ serializer_kwargs: Optional[dict] = None,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling native Python objects, serializing them to JSON strings for transmission.
@@ -77,7 +97,9 @@ def establish(self):
"""
Establish the connection to the server.
"""
- self._server = rospy.Service(self.out_topic, ROSNativeObjectService, self._service_callback)
+ self._server = rospy.Service(
+ self.out_topic, ROSNativeObjectService, self._service_callback
+ )
self.established = True
def await_request(self, *args, **kwargs):
@@ -94,7 +116,11 @@ def await_request(self, *args, **kwargs):
self.establish()
try:
request = ROSNativeObjectServer.RECEIVE_QUEUE.get(block=True)
- [args, kwargs] = json.loads(request.data, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ [args, kwargs] = json.loads(
+ request.data,
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
return args, kwargs
except rospy.ServiceException as e:
logging.error("[ROS] Service call failed: %s" % e)
@@ -118,14 +144,20 @@ def reply(self, obj):
:param obj: Any: The Python object to be serialized and sent
"""
try:
- obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ obj_str = json.dumps(
+ obj,
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
obj_msg = std_msgs.msg.String()
obj_msg.data = obj_str
ROSNativeObjectServer.SEND_QUEUE.put(obj_msg, block=False)
except queue.Full:
- logging.warning(f"[ROS] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
@Servers.register("Image", "ros")
@@ -133,9 +165,18 @@ class ROSImageServer(ROSServer):
SEND_QUEUE = queue.Queue(maxsize=1)
RECEIVE_QUEUE = queue.Queue(maxsize=1)
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling native Python objects, serializing them to JSON strings for transmission.
@@ -150,7 +191,9 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
"""
super().__init__(name, out_topic, carrier=carrier, **kwargs)
if "jpg" in kwargs:
- logging.warning("[ROS] ROS currently does not support JPG encoding in REQ/REP. Using raw image.")
+ logging.warning(
+ "[ROS] ROS currently does not support JPG encoding in REQ/REP. Using raw image."
+ )
kwargs.pop("jpg")
self.width = width
self.height = height
@@ -158,10 +201,10 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
self.fp = fp
if self.fp:
- self._encoding = '32FC3' if self.rgb else '32FC1'
+ self._encoding = "32FC3" if self.rgb else "32FC1"
self._type = np.float32
else:
- self._encoding = 'bgr8' if self.rgb else 'mono8'
+ self._encoding = "bgr8" if self.rgb else "mono8"
self._type = np.uint8
self._plugin_kwargs = kwargs
@@ -174,7 +217,9 @@ def establish(self):
"""
Establish the connection to the server.
"""
- self._server = rospy.Service(self.out_topic, ROSImageService, self._service_callback)
+ self._server = rospy.Service(
+ self.out_topic, ROSImageService, self._service_callback
+ )
self.established = True
def await_request(self, *args, **kwargs):
@@ -191,7 +236,11 @@ def await_request(self, *args, **kwargs):
self.establish()
try:
request = ROSImageServer.RECEIVE_QUEUE.get(block=True)
- [args, kwargs] = json.loads(request.data, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ [args, kwargs] = json.loads(
+ request.data,
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
return args, kwargs
except rospy.ServiceException as e:
logging.error("[ROS] Service call failed: %s" % e)
@@ -215,22 +264,32 @@ def reply(self, img: np.ndarray):
:param img: np.ndarray: Image to publish
"""
try:
- if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \
- not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)):
+ if (
+ 0 < self.width != img.shape[1]
+ or 0 < self.height != img.shape[0]
+ or not (
+ (img.ndim == 2 and not self.rgb)
+ or (img.ndim == 3 and self.rgb and img.shape[2] == 3)
+ )
+ ):
raise ValueError("Incorrect image shape for publisher")
- img = np.require(img, dtype=self._type, requirements='C')
+ img = np.require(img, dtype=self._type, requirements="C")
img_msg = sensor_msgs.msg.Image()
img_msg.header.stamp = rospy.Time.now()
img_msg.height = img.shape[0]
img_msg.width = img.shape[1]
img_msg.encoding = self._encoding
- img_msg.is_bigendian = img.dtype.byteorder == '>' or (img.dtype.byteorder == '=' and sys.byteorder == 'big')
+ img_msg.is_bigendian = img.dtype.byteorder == ">" or (
+ img.dtype.byteorder == "=" and sys.byteorder == "big"
+ )
img_msg.step = img.strides[0]
img_msg.data = img.tobytes()
ROSImageServer.SEND_QUEUE.put(img_msg, block=False)
except queue.Full:
- logging.warning(f"[ROS] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
@Servers.register("AudioChunk", "ros")
@@ -238,9 +297,17 @@ class ROSAudioChunkServer(ROSServer):
SEND_QUEUE = queue.Queue(maxsize=1)
RECEIVE_QUEUE = queue.Queue(maxsize=1)
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
- channels: int = 1, rate: int = 44100, chunk: int = -1,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling audio data as numpy arrays.
@@ -271,12 +338,18 @@ def establish(self):
from wrapyfi_ros_interfaces.srv import ROSAudioService
except ImportError:
import wrapyfi
- logging.error("[ROS] Could not import ROSAudioService. "
- "Make sure the ROS services in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS] Could not import ROSAudioService. "
+ "Make sure the ROS services in wrapyfi_extensions/wrapyfi_ros_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros_interfaces_lnk.html"
+ )
sys.exit(1)
- self._server = rospy.Service(self.out_topic, ROSAudioService, self._service_callback)
+ self._server = rospy.Service(
+ self.out_topic, ROSAudioService, self._service_callback
+ )
self._rep_msg = ROSAudioService._response_class().response
self.established = True
@@ -294,8 +367,11 @@ def await_request(self, *args, **kwargs):
self.establish()
try:
request = ROSAudioChunkServer.RECEIVE_QUEUE.get(block=True)
- [args, kwargs] = json.loads(request.request, object_hook=self._plugin_decoder_hook,
- **self._deserializer_kwargs)
+ [args, kwargs] = json.loads(
+ request.request,
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
return args, kwargs
except rospy.ServiceException as e:
logging.error("[ROS] Service call failed: %s" % e)
@@ -329,18 +405,22 @@ def reply(self, aud: Tuple[np.ndarray, int]):
self.channels = channels if self.channels == -1 else self.channels
if 0 < self.chunk != chunk or 0 < self.channels != channels:
raise ValueError("Incorrect audio shape for publisher")
- aud = np.require(aud, dtype=np.float32, requirements='C')
+ aud = np.require(aud, dtype=np.float32, requirements="C")
aud_msg = self._rep_msg
aud_msg.header.stamp = rospy.Time.now()
aud_msg.chunk_size = chunk
aud_msg.channels = channels
aud_msg.sample_rate = rate
- aud_msg.is_bigendian = aud.dtype.byteorder == '>' or (aud.dtype.byteorder == '=' and sys.byteorder == 'big')
- aud_msg.encoding = 'S16BE' if aud_msg.is_bigendian else 'S16LE'
+ aud_msg.is_bigendian = aud.dtype.byteorder == ">" or (
+ aud.dtype.byteorder == "=" and sys.byteorder == "big"
+ )
+ aud_msg.encoding = "S16BE" if aud_msg.is_bigendian else "S16LE"
aud_msg.step = aud.strides[0]
aud_msg.data = aud.tobytes() # (aud * 32767.0).tobytes()
ROSAudioChunkServer.SEND_QUEUE.put(aud_msg, block=False)
except queue.Full:
- logging.warning(f"[ROS] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
diff --git a/wrapyfi/servers/ros2.py b/wrapyfi/servers/ros2.py
index 89f16ae..7ab1eda 100755
--- a/wrapyfi/servers/ros2.py
+++ b/wrapyfi/servers/ros2.py
@@ -22,7 +22,9 @@
class ROS2Server(Server, Node):
- def __init__(self, name: str, out_topic: str, ros2_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self, name: str, out_topic: str, ros2_kwargs: Optional[dict] = None, **kwargs
+ ):
"""
Initialize the server.
@@ -34,7 +36,8 @@ def __init__(self, name: str, out_topic: str, ros2_kwargs: Optional[dict] = None
carrier = "tcp"
if "carrier" in kwargs and kwargs["carrier"] not in ["", None]:
logging.warning(
- "[ROS 2] ROS 2 currently does not support explicit carrier setting for REQ/REP pattern. Using TCP.")
+ "[ROS 2] ROS 2 currently does not support explicit carrier setting for REQ/REP pattern. Using TCP."
+ )
if "carrier" in kwargs:
del kwargs["carrier"]
@@ -62,8 +65,14 @@ class ROS2NativeObjectServer(ROS2Server):
SEND_QUEUE = queue.Queue(maxsize=1)
RECEIVE_QUEUE = queue.Queue(maxsize=1)
- def __init__(self, name: str, out_topic: str,
- serializer_kwargs: Optional[dict] = None, deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ serializer_kwargs: Optional[dict] = None,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling native Python objects, serializing them to JSON strings for transmission.
@@ -89,13 +98,19 @@ def establish(self):
from wrapyfi_ros2_interfaces.srv import ROS2NativeObjectService
except ImportError:
import wrapyfi
- logging.error("[ROS 2] Could not import ROS2NativeObjectService. "
- "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros2_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS 2] Could not import ROS2NativeObjectService. "
+ "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros2_interfaces_lnk.html"
+ )
sys.exit(1)
- self._server = self.create_service(ROS2NativeObjectService, self.out_topic, self._service_callback)
+ self._server = self.create_service(
+ ROS2NativeObjectService, self.out_topic, self._service_callback
+ )
self._req_msg = ROS2NativeObjectService.Request()
self._rep_msg = ROS2NativeObjectService.Response()
@@ -114,13 +129,18 @@ def await_request(self, *args, **kwargs):
if not self.established:
self.establish()
try:
- self._background_callback = threading.Thread(name='ros2_server', target=rclpy.spin_once,
- args=(self,), kwargs={})
+ self._background_callback = threading.Thread(
+ name="ros2_server", target=rclpy.spin_once, args=(self,), kwargs={}
+ )
self._background_callback.setDaemon(True)
self._background_callback.start()
request = ROS2NativeObjectServer.RECEIVE_QUEUE.get(block=True)
- [args, kwargs] = json.loads(request.request, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ [args, kwargs] = json.loads(
+ request.request,
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
return args, kwargs
except Exception as e:
logging.error("[ROS 2] Service call failed %s" % e)
@@ -145,13 +165,19 @@ def reply(self, obj):
:param obj: Any: The Python object to be serialized and sent
"""
try:
- obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ obj_str = json.dumps(
+ obj,
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
self._rep_msg.response = obj_str
ROS2NativeObjectServer.SEND_QUEUE.put(self._rep_msg, block=False)
except queue.Full:
- logging.warning(f"[ROS 2] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS 2] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
@Servers.register("Image", "ros2")
@@ -159,9 +185,18 @@ class ROS2ImageServer(ROS2Server):
SEND_QUEUE = queue.Queue(maxsize=1)
RECEIVE_QUEUE = queue.Queue(maxsize=1)
- def __init__(self, name: str, out_topic: str,
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling native Python objects, serializing them to JSON strings for transmission.
@@ -187,13 +222,13 @@ def __init__(self, name: str, out_topic: str,
self.jpg = jpg
if self.fp:
- self._encoding = '32FC3' if self.rgb else '32FC1'
+ self._encoding = "32FC3" if self.rgb else "32FC1"
self._type = np.float32
else:
- self._encoding = 'bgr8' if self.rgb else 'mono8'
+ self._encoding = "bgr8" if self.rgb else "mono8"
self._type = np.uint8
if self.jpg:
- self._encoding = 'jpeg'
+ self._encoding = "jpeg"
self._type = np.uint8
self._server = None
@@ -203,20 +238,31 @@ def establish(self):
Establish the connection to the server.
"""
try:
- from wrapyfi_ros2_interfaces.srv import ROS2ImageService, ROS2CompressedImageService
+ from wrapyfi_ros2_interfaces.srv import (
+ ROS2ImageService,
+ ROS2CompressedImageService,
+ )
except ImportError:
import wrapyfi
- logging.error("[ROS 2] Could not import ROS2NativeObjectService. "
- "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros2_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS 2] Could not import ROS2NativeObjectService. "
+ "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros2_interfaces_lnk.html"
+ )
sys.exit(1)
if self.jpg:
- self._server = self.create_service(ROS2CompressedImageService, self.out_topic, self._service_callback)
+ self._server = self.create_service(
+ ROS2CompressedImageService, self.out_topic, self._service_callback
+ )
self._req_msg = ROS2CompressedImageService.Request()
self._rep_msg = ROS2CompressedImageService.Response()
else:
- self._server = self.create_service(ROS2ImageService, self.out_topic, self._service_callback)
+ self._server = self.create_service(
+ ROS2ImageService, self.out_topic, self._service_callback
+ )
self._req_msg = ROS2ImageService.Request()
self._rep_msg = ROS2ImageService.Response()
self.established = True
@@ -234,13 +280,18 @@ def await_request(self, *args, **kwargs):
if not self.established:
self.establish()
try:
- self._background_callback = threading.Thread(name='ros2_server', target=rclpy.spin_once,
- args=(self,), kwargs={})
+ self._background_callback = threading.Thread(
+ name="ros2_server", target=rclpy.spin_once, args=(self,), kwargs={}
+ )
self._background_callback.setDaemon(True)
self._background_callback.start()
request = ROS2ImageServer.RECEIVE_QUEUE.get(block=True)
- [args, kwargs] = json.loads(request.request, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ [args, kwargs] = json.loads(
+ request.request,
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
return args, kwargs
except Exception as e:
logging.error("[ROS 2] Service call failed %s" % e)
@@ -265,28 +316,38 @@ def reply(self, img: np.ndarray):
:param img: np.ndarray: Image to send formatted as a cv2 image - np.ndarray[img_height, img_width, channels]
"""
try:
- if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \
- not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)):
+ if (
+ 0 < self.width != img.shape[1]
+ or 0 < self.height != img.shape[0]
+ or not (
+ (img.ndim == 2 and not self.rgb)
+ or (img.ndim == 3 and self.rgb and img.shape[2] == 3)
+ )
+ ):
raise ValueError("Incorrect image shape for publisher")
- img = np.require(img, dtype=self._type, requirements='C')
+ img = np.require(img, dtype=self._type, requirements="C")
img_msg = self._rep_msg.response
if self.jpg:
img_msg.header.stamp = rclpy.clock.Clock().now().to_msg()
img_msg.format = "jpeg"
- img_msg.data = np.array(cv2.imencode('.jpg', img)[1]).tobytes()
+ img_msg.data = np.array(cv2.imencode(".jpg", img)[1]).tobytes()
else:
img_msg.header.stamp = rclpy.clock.Clock().now().to_msg()
img_msg.height = img.shape[0]
img_msg.width = img.shape[1]
img_msg.encoding = self._encoding
- img_msg.is_bigendian = img.dtype.byteorder == '>' or (img.dtype.byteorder == '=' and sys.byteorder == 'big')
+ img_msg.is_bigendian = img.dtype.byteorder == ">" or (
+ img.dtype.byteorder == "=" and sys.byteorder == "big"
+ )
img_msg.step = img.strides[0]
img_msg.data = img.tobytes()
self._rep_msg.response = img_msg
ROS2ImageServer.SEND_QUEUE.put(self._rep_msg, block=False)
except queue.Full:
- logging.warning(f"[ROS 2] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS 2] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
@Servers.register("AudioChunk", "ros2")
@@ -294,9 +355,16 @@ class ROS2AudioChunkServer(ROS2Server):
SEND_QUEUE = queue.Queue(maxsize=1)
RECEIVE_QUEUE = queue.Queue(maxsize=1)
- def __init__(self, name: str, out_topic: str,
- channels: int = 1, rate: int = 44100, chunk: int = -1,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling audio data as numpy arrays.
@@ -325,12 +393,18 @@ def establish(self):
from wrapyfi_ros2_interfaces.srv import ROS2AudioService
except ImportError:
import wrapyfi
- logging.error("[ROS 2] Could not import ROS2AudioService. "
- "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
- "Refer to the documentation for more information: \n" +
- wrapyfi.__doc__ + "ros2_interfaces_lnk.html")
+
+ logging.error(
+ "[ROS 2] Could not import ROS2AudioService. "
+ "Make sure the ROS 2 services in wrapyfi_extensions/wrapyfi_ros2_interfaces are compiled. "
+ "Refer to the documentation for more information: \n"
+ + wrapyfi.__doc__
+ + "ros2_interfaces_lnk.html"
+ )
sys.exit(1)
- self._server = self.create_service(ROS2AudioService, self.out_topic, self._service_callback)
+ self._server = self.create_service(
+ ROS2AudioService, self.out_topic, self._service_callback
+ )
self._req_msg = ROS2AudioService.Request()
self._rep_msg = ROS2AudioService.Response()
self.established = True
@@ -348,13 +422,18 @@ def await_request(self, *args, **kwargs):
if not self.established:
self.establish()
try:
- self._background_callback = threading.Thread(name='ros2_server', target=rclpy.spin_once,
- args=(self,), kwargs={})
+ self._background_callback = threading.Thread(
+ name="ros2_server", target=rclpy.spin_once, args=(self,), kwargs={}
+ )
self._background_callback.setDaemon(True)
self._background_callback.start()
request = ROS2AudioChunkServer.RECEIVE_QUEUE.get(block=True)
- [args, kwargs] = json.loads(request.request, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ [args, kwargs] = json.loads(
+ request.request,
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
return args, kwargs
except Exception as e:
logging.error("[ROS 2] Service call failed %s" % e)
@@ -389,19 +468,23 @@ def reply(self, aud: Tuple[np.ndarray, int]):
self.channels = channels if self.channels == -1 else self.channels
if 0 < self.chunk != chunk or 0 < self.channels != channels:
raise ValueError("Incorrect audio shape for publisher")
- aud = np.require(aud, dtype=np.float32, requirements='C')
+ aud = np.require(aud, dtype=np.float32, requirements="C")
aud_msg = self._rep_msg.response
aud_msg.header.stamp = self.get_clock().now().to_msg()
aud_msg.chunk_size = chunk
aud_msg.channels = channels
aud_msg.sample_rate = rate
- aud_msg.is_bigendian = aud.dtype.byteorder == '>' or (aud.dtype.byteorder == '=' and sys.byteorder == 'big')
- aud_msg.encoding = 'S16BE' if aud_msg.is_bigendian else 'S16LE'
+ aud_msg.is_bigendian = aud.dtype.byteorder == ">" or (
+ aud.dtype.byteorder == "=" and sys.byteorder == "big"
+ )
+ aud_msg.encoding = "S16BE" if aud_msg.is_bigendian else "S16LE"
aud_msg.step = aud.strides[0]
aud_msg.data = aud.tobytes() # (aud * 32767.0).tobytes()
self._rep_msg.response = aud_msg
ROS2AudioChunkServer.SEND_QUEUE.put(self._rep_msg, block=False)
except queue.Full:
- logging.warning(f"[ROS 2] Discarding data because queue is full. "
- f"This happened due to bad synchronization in {self.__name__}")
+ logging.warning(
+ f"[ROS 2] Discarding data because queue is full. "
+ f"This happened due to bad synchronization in {self.__name__}"
+ )
diff --git a/wrapyfi/servers/yarp.py b/wrapyfi/servers/yarp.py
index ed67075..cdc0c53 100755
--- a/wrapyfi/servers/yarp.py
+++ b/wrapyfi/servers/yarp.py
@@ -12,9 +12,16 @@
class YarpServer(Server):
- def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp",
- out_topic_connect: Optional[str] = None, persistent: bool = True,
- yarp_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ out_topic_connect: Optional[str] = None,
+ persistent: bool = True,
+ yarp_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the server.
@@ -26,7 +33,13 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc
:param yarp_kwargs: dict: Additional kwargs for the Yarp middleware
:param kwargs: dict: Additional kwargs for the server
"""
- super().__init__(name, out_topic, carrier=carrier, out_topic_connect=out_topic_connect, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ out_topic_connect=out_topic_connect,
+ **kwargs,
+ )
YarpMiddleware.activate(**yarp_kwargs or {})
self.style = yarp.ContactStyle()
self.style.persistent = persistent
@@ -49,9 +62,17 @@ def __del__(self):
@Servers.register("NativeObject", "yarp")
class YarpNativeObjectServer(YarpServer):
- def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp",
- out_topic_connect: Optional[str] = None, persistent: bool = True,
- serializer_kwargs: Optional[dict] = None, deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ out_topic_connect: Optional[str] = None,
+ persistent: bool = True,
+ serializer_kwargs: Optional[dict] = None,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling native Python objects, serializing them to JSON strings for transmission.
@@ -64,7 +85,14 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc
:param serializer_kwargs: dict: Additional kwargs for the serializer
:param deserializer_kwargs: dict: Additional kwargs for the deserializer
"""
- super().__init__(name, out_topic, carrier=carrier, out_topic_connect=out_topic_connect, persistent=persistent, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ out_topic_connect=out_topic_connect,
+ persistent=persistent,
+ **kwargs,
+ )
self._plugin_encoder = JsonEncoder
self._plugin_kwargs = kwargs
self._serializer_kwargs = serializer_kwargs or {}
@@ -80,11 +108,17 @@ def establish(self):
self._port = yarp.RpcServer()
self._port.open(self.out_topic)
if self.style.persistent:
- self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.style)
+ self._netconnect = yarp.Network.connect(
+ self.out_topic, self.out_topic_connect, self.style
+ )
else:
- self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.carrier)
+ self._netconnect = yarp.Network.connect(
+ self.out_topic, self.out_topic_connect, self.carrier
+ )
- self._netconnect = yarp.Network.connect(self.out_topic, self.out_topic_connect, self.carrier)
+ self._netconnect = yarp.Network.connect(
+ self.out_topic, self.out_topic_connect, self.carrier
+ )
if self.persistent:
self.established = True
@@ -106,7 +140,11 @@ def await_request(self, *args, **kwargs):
request = False
while not request:
request = self._port.read(obj_msg, True)
- [args, kwargs] = json.loads(obj_msg.get(0).asString(), object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ [args, kwargs] = json.loads(
+ obj_msg.get(0).asString(),
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
return args, kwargs
except Exception as e:
logging.error("[YARP] Service call failed: %s" % e)
@@ -119,8 +157,12 @@ def reply(self, obj):
:param obj: Any: The Python object to be serialized and sent
"""
- obj_str = json.dumps(obj, cls=self._plugin_encoder, **self._plugin_kwargs,
- serializer_kwrags=self._serializer_kwargs)
+ obj_str = json.dumps(
+ obj,
+ cls=self._plugin_encoder,
+ **self._plugin_kwargs,
+ serializer_kwrags=self._serializer_kwargs,
+ )
obj_msg = yarp.Bottle()
obj_msg.clear()
obj_msg.addString(obj_str)
@@ -132,10 +174,20 @@ def reply(self, obj):
@Servers.register("Image", "yarp")
class YarpImageServer(YarpNativeObjectServer):
- def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp",
- out_topic_connect: Optional[str] = None, persistent: bool = True,
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ out_topic_connect: Optional[str] = None,
+ persistent: bool = True,
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling image data as numpy arrays, serializing them to JSON strings for transmission.
@@ -152,9 +204,19 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc
:param deserializer_kwargs: dict: Additional kwargs for the deserializer
"""
if "jpg" in kwargs:
- logging.warning("[YARP] YARP currently does not support JPG encoding in REQ/REP. Using raw image.")
+ logging.warning(
+ "[YARP] YARP currently does not support JPG encoding in REQ/REP. Using raw image."
+ )
kwargs.pop("jpg")
- super().__init__(name, out_topic, carrier=carrier, out_topic_connect=out_topic_connect, persistent=persistent, deserializer_kwargs=deserializer_kwargs, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ out_topic_connect=out_topic_connect,
+ persistent=persistent,
+ deserializer_kwargs=deserializer_kwargs,
+ **kwargs,
+ )
self.width = width
self.height = height
self.rgb = rgb
@@ -166,8 +228,14 @@ def reply(self, img: np.ndarray):
:param img: np.ndarray: Image to send formatted as a cv2 image - np.ndarray[img_height, img_width, channels]
"""
- if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \
- not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)):
+ if (
+ 0 < self.width != img.shape[1]
+ or 0 < self.height != img.shape[0]
+ or not (
+ (img.ndim == 2 and not self.rgb)
+ or (img.ndim == 3 and self.rgb and img.shape[2] == 3)
+ )
+ ):
raise ValueError("Incorrect image shape for publisher")
# img = np.require(img, dtype=self._type, requirements='C')
super().reply(img)
@@ -175,10 +243,19 @@ def reply(self, img: np.ndarray):
@Servers.register("AudioChunk", "yarp")
class YarpAudioChunkServer(YarpNativeObjectServer):
- def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mcast"] = "tcp",
- out_topic_connect: Optional[str] = None, persistent: bool = True,
- channels: int = 1, rate: int = 44100, chunk: int = -1,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: Literal["tcp", "udp", "mcast"] = "tcp",
+ out_topic_connect: Optional[str] = None,
+ persistent: bool = True,
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling audio data as numpy arrays, serializing them to JSON strings for transmission.
@@ -193,7 +270,15 @@ def __init__(self, name: str, out_topic: str, carrier: Literal["tcp", "udp", "mc
:param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio)
:param deserializer_kwargs: dict: Additional kwargs for the deserializer
"""
- super().__init__(name, out_topic, carrier=carrier, out_topic_connect=out_topic_connect, persistent=persistent, deserializer_kwargs=deserializer_kwargs, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ out_topic_connect=out_topic_connect,
+ persistent=persistent,
+ deserializer_kwargs=deserializer_kwargs,
+ **kwargs,
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -214,5 +299,5 @@ def reply(self, aud: Tuple[np.ndarray, int]):
self.channels = channels if self.channels == -1 else self.channels
if 0 < self.chunk != chunk or 0 < self.channels != channels:
raise ValueError("Incorrect audio shape for publisher")
- aud = np.require(aud, dtype=np.float32, requirements='C')
+ aud = np.require(aud, dtype=np.float32, requirements="C")
super().reply((chunk, channels, rate, aud))
diff --git a/wrapyfi/servers/zeromq.py b/wrapyfi/servers/zeromq.py
index 32271a3..4c21eb8 100644
--- a/wrapyfi/servers/zeromq.py
+++ b/wrapyfi/servers/zeromq.py
@@ -16,16 +16,27 @@
SOCKET_IP = os.environ.get("WRAPYFI_ZEROMQ_SOCKET_IP", "127.0.0.1")
SOCKET_PUB_PORT = int(os.environ.get("WRAPYFI_ZEROMQ_SOCKET_REQ_PORT", 5558))
SOCKET_SUB_PORT = int(os.environ.get("WRAPYFI_ZEROMQ_SOCKET_REP_PORT", 5559))
-START_PROXY_BROKER = os.environ.get("WRAPYFI_ZEROMQ_START_PROXY_BROKER", True) != "False"
+START_PROXY_BROKER = (
+ os.environ.get("WRAPYFI_ZEROMQ_START_PROXY_BROKER", True) != "False"
+)
PROXY_BROKER_SPAWN = os.environ.get("WRAPYFI_ZEROMQ_PROXY_BROKER_SPAWN", "process")
WATCHDOG_POLL_REPEAT = None
class ZeroMQServer(Server):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
- socket_ip: str = SOCKET_IP, socket_rep_port: int = SOCKET_PUB_PORT, socket_req_port: int = SOCKET_SUB_PORT,
- start_proxy_broker: bool = START_PROXY_BROKER, proxy_broker_spawn: bool = PROXY_BROKER_SPAWN,
- zeromq_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ socket_ip: str = SOCKET_IP,
+ socket_rep_port: int = SOCKET_PUB_PORT,
+ socket_req_port: int = SOCKET_SUB_PORT,
+ start_proxy_broker: bool = START_PROXY_BROKER,
+ proxy_broker_spawn: bool = PROXY_BROKER_SPAWN,
+ zeromq_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Initialize the server and start the device broker if necessary.
@@ -41,20 +52,26 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
:param kwargs: dict: Additional kwargs for the server
"""
if out_topic != "":
- logging.warning(f"[ZeroMQ] ZeroMQ does not support topics for the REQ/REP pattern. Topic {out_topic} removed")
+ logging.warning(
+ f"[ZeroMQ] ZeroMQ does not support topics for the REQ/REP pattern. Topic {out_topic} removed"
+ )
out_topic = ""
if carrier or carrier != "tcp":
- logging.warning("[ZeroMQ] ZeroMQ does not support other carriers than TCP for REQ/REP pattern. Using TCP.")
+ logging.warning(
+ "[ZeroMQ] ZeroMQ does not support other carriers than TCP for REQ/REP pattern. Using TCP."
+ )
carrier = "tcp"
super().__init__(name, out_topic, carrier=carrier, **kwargs)
self.socket_rep_address = f"{carrier}://{socket_ip}:{socket_rep_port}"
self.socket_req_address = f"{carrier}://{socket_ip}:{socket_req_port}"
if start_proxy_broker:
- ZeroMQMiddlewareReqRep.activate(socket_rep_address=self.socket_rep_address,
- socket_req_address=self.socket_req_address,
- proxy_broker_spawn=proxy_broker_spawn,
- **zeromq_kwargs or {})
+ ZeroMQMiddlewareReqRep.activate(
+ socket_rep_address=self.socket_rep_address,
+ socket_req_address=self.socket_req_address,
+ proxy_broker_spawn=proxy_broker_spawn,
+ **zeromq_kwargs or {},
+ )
else:
ZeroMQMiddlewareReqRep.activate(**zeromq_kwargs or {})
@@ -72,8 +89,15 @@ def __del__(self):
@Servers.register("NativeObject", "zeromq")
class ZeroMQNativeObjectServer(ZeroMQServer):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
- serializer_kwargs: Optional[dict] = None, deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ serializer_kwargs: Optional[dict] = None,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling native Python objects, serializing them to JSON strings for transmission.
@@ -98,9 +122,13 @@ def establish(self, **kwargs):
self._socket = zmq.Context().instance().socket(zmq.REP)
for socket_property in ZeroMQMiddlewareReqRep().zeromq_kwargs.items():
if isinstance(socket_property[1], str):
- self._socket.setsockopt_string(getattr(zmq, socket_property[0]), socket_property[1])
+ self._socket.setsockopt_string(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
else:
- self._socket.setsockopt(getattr(zmq, socket_property[0]), socket_property[1])
+ self._socket.setsockopt(
+ getattr(zmq, socket_property[0]), socket_property[1]
+ )
self._socket.connect(self.socket_req_address)
self.established = True
@@ -116,7 +144,11 @@ def await_request(self, *args, **kwargs):
"""
message = self._socket.recv_string()
try:
- request = json.loads(message, object_hook=self._plugin_decoder_hook, **self._deserializer_kwargs)
+ request = json.loads(
+ message,
+ object_hook=self._plugin_decoder_hook,
+ **self._deserializer_kwargs,
+ )
args, kwargs = request
return args, kwargs
except json.JSONDecodeError as e:
@@ -137,9 +169,19 @@ def reply(self, obj):
@Servers.register("Image", "zeromq")
class ZeroMQImageServer(ZeroMQNativeObjectServer):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
- width: int = -1, height: int = -1, rgb: bool = True, fp: bool = False, jpg: bool = False,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ width: int = -1,
+ height: int = -1,
+ rgb: bool = True,
+ fp: bool = False,
+ jpg: bool = False,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling image data as numpy arrays, serializing them to JSON strings for transmission.
@@ -153,7 +195,13 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
:param jpg: bool: True if the image should be decompressed from JPG. Default is False
:param deserializer_kwargs: dict: Additional kwargs for the deserializer
"""
- super().__init__(name, out_topic, carrier=carrier, deserializer_kwargs=deserializer_kwargs, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ deserializer_kwargs=deserializer_kwargs,
+ **kwargs,
+ )
self.width = width
self.height = height
@@ -173,15 +221,21 @@ def reply(self, img: np.ndarray):
logging.warning("[ZeroMQ] Image is None. Skipping reply.")
return
- if 0 < self.width != img.shape[1] or 0 < self.height != img.shape[0] or \
- not ((img.ndim == 2 and not self.rgb) or (img.ndim == 3 and self.rgb and img.shape[2] == 3)):
+ if (
+ 0 < self.width != img.shape[1]
+ or 0 < self.height != img.shape[0]
+ or not (
+ (img.ndim == 2 and not self.rgb)
+ or (img.ndim == 3 and self.rgb and img.shape[2] == 3)
+ )
+ ):
raise ValueError("Incorrect image shape for server reply")
- if not img.flags['C_CONTIGUOUS']:
+ if not img.flags["C_CONTIGUOUS"]:
img = np.ascontiguousarray(img)
if self.jpg:
- _, img_encoded = cv2.imencode('.jpg', img)
+ _, img_encoded = cv2.imencode(".jpg", img)
img_bytes = img_encoded.tobytes()
self._socket.send(img_bytes)
else:
@@ -193,9 +247,17 @@ def reply(self, img: np.ndarray):
@Servers.register("AudioChunk", "zeromq")
class ZeroMQAudioChunkServer(ZeroMQNativeObjectServer):
- def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
- channels: int = 1, rate: int = 44100, chunk: int = -1,
- deserializer_kwargs: Optional[dict] = None, **kwargs):
+ def __init__(
+ self,
+ name: str,
+ out_topic: str,
+ carrier: str = "tcp",
+ channels: int = 1,
+ rate: int = 44100,
+ chunk: int = -1,
+ deserializer_kwargs: Optional[dict] = None,
+ **kwargs,
+ ):
"""
Specific server handling audio data as numpy arrays, serializing them to JSON strings for transmission.
@@ -207,7 +269,13 @@ def __init__(self, name: str, out_topic: str, carrier: str = "tcp",
:param chunk: int: Number of samples in the audio chunk. Default is -1 (use the chunk size of the received audio)
:param deserializer_kwargs: dict: Additional kwargs for the deserializer
"""
- super().__init__(name, out_topic, carrier=carrier, deserializer_kwargs=deserializer_kwargs, **kwargs)
+ super().__init__(
+ name,
+ out_topic,
+ carrier=carrier,
+ deserializer_kwargs=deserializer_kwargs,
+ **kwargs,
+ )
self.channels = channels
self.rate = rate
self.chunk = chunk
@@ -226,8 +294,8 @@ def reply(self, aud: Tuple[np.ndarray, int]):
self.channels = channels if self.channels == -1 else self.channels
if 0 < self.chunk != chunk or 0 < self.channels != channels:
raise ValueError("Incorrect audio shape for publisher")
- aud = np.require(aud, dtype=np.float32, requirements='C')
+ aud = np.require(aud, dtype=np.float32, requirements="C")
aud_list = aud.tolist()
aud_json = json.dumps({"aud": (int(chunk), int(channels), int(rate), aud_list)})
- self._socket.send_string(aud_json)
\ No newline at end of file
+ self._socket.send_string(aud_json)
diff --git a/wrapyfi/standalone/zeromq_param_server.py b/wrapyfi/standalone/zeromq_param_server.py
index c0238bc..d50e168 100644
--- a/wrapyfi/standalone/zeromq_param_server.py
+++ b/wrapyfi/standalone/zeromq_param_server.py
@@ -11,12 +11,12 @@
def access_nested_dict(d, keys):
- return reduce(lambda x,y: x[y], keys[:-1], d)[keys[-1]]
+ return reduce(lambda x, y: x[y], keys[:-1], d)[keys[-1]]
def parse_prefix(param_full: str, topics: dict):
# split the string by '/' to get the individual components
- components = param_full.split('/')
+ components = param_full.split("/")
# If there is only one component in the list, add it to the topics dictionary
# as a value and return
if len(components) == 1:
@@ -45,29 +45,44 @@ def reverse_parse_prefix(topics: dict, prefix: str = ""):
# If the value of the key is a string, add the key and value to the list
# of strings as a '/' separated string
if isinstance(topics[key], str):
- strings.append(key + '/' + topics[key])
+ strings.append(key + "/" + topics[key])
# If the value of the key is a dictionary, recursively call generate_strings
# with the value and add the returned list of strings to the list of strings
# with the key as a prefix
else:
nested_strings = reverse_parse_prefix(topics[key])
for string in nested_strings:
- strings.append(key + '/' + string)
+ strings.append(key + "/" + string)
# return the list of strings filtered for repeated /'s
- return [prefix + string.replace('//', '/') for string in strings]
-
-
-def main(role, param_ip, param_pub_port, param_sub_port, param_reqrep_port, param_poll_interval, param_poll_repeat, **kwargs):
+ return [prefix + string.replace("//", "/") for string in strings]
+
+
+def main(
+ role,
+ param_ip,
+ param_pub_port,
+ param_sub_port,
+ param_reqrep_port,
+ param_poll_interval,
+ param_poll_repeat,
+ **kwargs,
+):
if role == "server":
param_pub_address = f"tcp://{param_ip}:{param_pub_port}"
param_sub_address = f"tcp://{param_ip}:{param_sub_port}"
param_reqrep_address = f"tcp://{param_ip}:{param_reqrep_port}"
- ZeroMQMiddlewareParamServer.activate(**{"param_pub_address": param_pub_address,
- "param_sub_address": param_sub_address,
- "param_reqrep_address": param_reqrep_address,
- "param_poll_interval": param_poll_interval,
- "proxy_broker_spawn": "process", "verbose": True}, **kwargs)
+ ZeroMQMiddlewareParamServer.activate(
+ **{
+ "param_pub_address": param_pub_address,
+ "param_sub_address": param_sub_address,
+ "param_reqrep_address": param_reqrep_address,
+ "param_poll_interval": param_poll_interval,
+ "proxy_broker_spawn": "process",
+ "verbose": True,
+ },
+ **kwargs,
+ )
while True:
pass
@@ -87,17 +102,20 @@ def main(role, param_ip, param_pub_port, param_sub_port, param_reqrep_port, para
while True:
# send a request to the request server starting with write, read, delete, set or get
- new_commands = str(input(f"input command: (default - {default_command})")) or default_command
+ new_commands = (
+ str(input(f"input command: (default - {default_command})"))
+ or default_command
+ )
default_command = new_commands
# pass a list of commands to the request server e.g. ['set /foo/bar/21', 'set /foo/bar/baz/42', 'get /foo/bar', 'delete /foo/bar/baz', 'read /foo']
- if '[' in new_commands:
- new_commands = new_commands.replace('[\'', '').replace('\']', '')
- new_commands = new_commands.split('\', \'')
+ if "[" in new_commands:
+ new_commands = new_commands.replace("['", "").replace("']", "")
+ new_commands = new_commands.split("', '")
# write supports writing full tree dicts e.g. write {'foo': {'bar': {'': '42', 'car': '43'}}}
- elif 'write' in new_commands:
- new_commands = new_commands.replace('write ', '')
- if new_commands.startswith('{'):
- new_commands = new_commands.replace('\n', '').replace('\t', '')
+ elif "write" in new_commands:
+ new_commands = new_commands.replace("write ", "")
+ if new_commands.startswith("{"):
+ new_commands = new_commands.replace("\n", "").replace("\t", "")
new_commands = json.loads(new_commands)
new_commands = reverse_parse_prefix(new_commands, prefix="set ")
# pass commands directly to the request server e.g. set /foo/bar/42
@@ -111,7 +129,7 @@ def main(role, param_ip, param_pub_port, param_sub_port, param_reqrep_port, para
print("Received reply from server: %s" % reply)
if "success:::" in reply:
topics = {}
- current_prefix = reply.split(":::")[1].encode('utf-8')
+ current_prefix = reply.split(":::")[1].encode("utf-8")
param_server.subscribe(current_prefix)
# time.sleep(0.2)
prev_params = Counter()
@@ -121,40 +139,91 @@ def main(role, param_ip, param_pub_port, param_sub_port, param_reqrep_port, para
try:
prefix, param, value = param_server.recv_multipart()
except zmq.error.Again:
- print("No new parameters received. Need atleast one topic to subscribe to.")
+ print(
+ "No new parameters received. Need atleast one topic to subscribe to."
+ )
break
# construct the full parameter name with the namespace prefix
- prefix, param, value = prefix.decode('utf-8'), param.decode('utf-8'), value.decode('utf-8')
- if (param in prev_params or param is None) and prev_params[param] == param_poll_repeat:
+ prefix, param, value = (
+ prefix.decode("utf-8"),
+ param.decode("utf-8"),
+ value.decode("utf-8"),
+ )
+ if (param in prev_params or param is None) and prev_params[
+ param
+ ] == param_poll_repeat:
break
prev_params[param] += 1
full_param = "/".join([prefix, param, value])
parse_prefix(full_param, topics)
# print("Received update: %s" % (full_param))
try:
- topic_results = access_nested_dict(topics, current_prefix.decode('utf-8').split('/'))
+ topic_results = access_nested_dict(
+ topics, current_prefix.decode("utf-8").split("/")
+ )
except KeyError:
- print(current_prefix.decode('utf-8') + " has no children")
+ print(current_prefix.decode("utf-8") + " has no children")
continue
- print(json.dumps({current_prefix.decode('utf-8'): topic_results}, indent=None, default=str))
- print("Reverse parse (always in set mode regardless of the transmitted command):")
- print(reverse_parse_prefix(topic_results, prefix=f"set {current_prefix.decode('utf-8')}/"))
+ print(
+ json.dumps(
+ {current_prefix.decode("utf-8"): topic_results},
+ indent=None,
+ default=str,
+ )
+ )
+ print(
+ "Reverse parse (always in set mode regardless of the transmitted command):"
+ )
+ print(
+ reverse_parse_prefix(
+ topic_results,
+ prefix=f"set {current_prefix.decode('utf-8')}/",
+ )
+ )
# close the connection
param_server.unsubscribe(current_prefix)
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("--param_ip", type=str, default="127.0.0.1", help="Parameter sever IP address")
- parser.add_argument("--param_pub_port", type=int, default=5655, help="Socket publishing (PUB) port")
- parser.add_argument("--param_sub_port", type=int, default=5656, help="Socket subscription (SUB) port")
- parser.add_argument("--param_reqrep_port", type=int, default=5659, help="Socket request (REQ)/reply (REP) port")
- parser.add_argument("--param_poll_interval", type=int, default=1, help="PUB/SUB poll interval in milliseconds "
- "(only used when role=server)")
- parser.add_argument("--param_poll_repeat", type=int, default=5, help="SUB poll repetition in terms of no. of "
- "attempts (only used when role=client)")
- parser.add_argument("--role", type=str, default="server", choices=["server", "client"],
- help="Parameter server in serving (server) or requesting (client) mode")
+ parser.add_argument(
+ "--param_ip", type=str, default="127.0.0.1", help="Parameter sever IP address"
+ )
+ parser.add_argument(
+ "--param_pub_port", type=int, default=5655, help="Socket publishing (PUB) port"
+ )
+ parser.add_argument(
+ "--param_sub_port",
+ type=int,
+ default=5656,
+ help="Socket subscription (SUB) port",
+ )
+ parser.add_argument(
+ "--param_reqrep_port",
+ type=int,
+ default=5659,
+ help="Socket request (REQ)/reply (REP) port",
+ )
+ parser.add_argument(
+ "--param_poll_interval",
+ type=int,
+ default=1,
+ help="PUB/SUB poll interval in milliseconds " "(only used when role=server)",
+ )
+ parser.add_argument(
+ "--param_poll_repeat",
+ type=int,
+ default=5,
+ help="SUB poll repetition in terms of no. of "
+ "attempts (only used when role=client)",
+ )
+ parser.add_argument(
+ "--role",
+ type=str,
+ default="server",
+ choices=["server", "client"],
+ help="Parameter server in serving (server) or requesting (client) mode",
+ )
return parser.parse_args()
diff --git a/wrapyfi/standalone/zeromq_proxy_broker.py b/wrapyfi/standalone/zeromq_proxy_broker.py
index ccde90e..88c0afb 100644
--- a/wrapyfi/standalone/zeromq_proxy_broker.py
+++ b/wrapyfi/standalone/zeromq_proxy_broker.py
@@ -1,5 +1,6 @@
import argparse
import logging
+
logging.getLogger().setLevel(logging.INFO)
import zmq
@@ -7,22 +8,42 @@
from wrapyfi.middlewares.zeromq import ZeroMQMiddlewareReqRep, ZeroMQMiddlewarePubSub
-def main(comm_type, socket_ip, socket_pub_port, socket_sub_port, socket_rep_port, socket_req_port, **kwargs):
+def main(
+ comm_type,
+ socket_ip,
+ socket_pub_port,
+ socket_sub_port,
+ socket_rep_port,
+ socket_req_port,
+ **kwargs,
+):
if comm_type == "pubsub":
socket_pub_address = f"tcp://{socket_ip}:{socket_pub_port}"
socket_sub_address = f"tcp://{socket_ip}:{socket_sub_port}"
- ZeroMQMiddlewarePubSub.activate(**{"socket_pub_address": socket_pub_address,
- "socket_sub_address": socket_sub_address,
- "proxy_broker_spawn": "process", "verbose": True}, **kwargs)
+ ZeroMQMiddlewarePubSub.activate(
+ **{
+ "socket_pub_address": socket_pub_address,
+ "socket_sub_address": socket_sub_address,
+ "proxy_broker_spawn": "process",
+ "verbose": True,
+ },
+ **kwargs,
+ )
while True:
pass
elif comm_type == "reqrep":
socket_rep_address = f"tcp://{socket_ip}:{socket_rep_port}"
socket_req_address = f"tcp://{socket_ip}:{socket_req_port}"
- ZeroMQMiddlewareReqRep.activate(**{"socket_rep_address": socket_rep_address,
- "socket_req_address": socket_req_address,
- "proxy_broker_spawn": "process", "verbose": True}, **kwargs)
+ ZeroMQMiddlewareReqRep.activate(
+ **{
+ "socket_rep_address": socket_rep_address,
+ "socket_req_address": socket_req_address,
+ "proxy_broker_spawn": "process",
+ "verbose": True,
+ },
+ **kwargs,
+ )
while True:
pass
@@ -44,27 +65,41 @@ def main(comm_type, socket_ip, socket_pub_port, socket_sub_port, socket_rep_port
event = dict(poller.poll(1000))
if xpub_socket in event:
message = xpub_socket.recv_multipart()
- #print("[ZeroMQ BROKER] xpub_socket recv message: %r" % message)
+ # print("[ZeroMQ BROKER] xpub_socket recv message: %r" % message)
xsub_socket.send_multipart(message)
if xsub_socket in event:
message = xsub_socket.recv_multipart()
- #print("[ZeroMQ BROKER] xsub_socket recv message: %r" % message)
+ # print("[ZeroMQ BROKER] xsub_socket recv message: %r" % message)
xpub_socket.send_multipart(message)
+
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("--socket_ip", type=str, default="127.0.0.1", help="Socket IP address")
- parser.add_argument("--socket_pub_port", type=int, default=5555, help="Socket publishing port")
- parser.add_argument("--socket_sub_port", type=int, default=5556, help="Socket subscription port")
- parser.add_argument("--socket_rep_port", type=int, default=5559, help="Socket reply port")
- parser.add_argument("--socket_req_port", type=int, default=5560, help="Socket request port")
- parser.add_argument("--comm_type", type=str, default="pubsub", choices=["pubsub", "pubsubpoll", "reqrep"],
- help="The zeromq communication pattern")
+ parser.add_argument(
+ "--socket_ip", type=str, default="127.0.0.1", help="Socket IP address"
+ )
+ parser.add_argument(
+ "--socket_pub_port", type=int, default=5555, help="Socket publishing port"
+ )
+ parser.add_argument(
+ "--socket_sub_port", type=int, default=5556, help="Socket subscription port"
+ )
+ parser.add_argument(
+ "--socket_rep_port", type=int, default=5559, help="Socket reply port"
+ )
+ parser.add_argument(
+ "--socket_req_port", type=int, default=5560, help="Socket request port"
+ )
+ parser.add_argument(
+ "--comm_type",
+ type=str,
+ default="pubsub",
+ choices=["pubsub", "pubsubpoll", "reqrep"],
+ help="The zeromq communication pattern",
+ )
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(**vars(args))
-
-
diff --git a/wrapyfi/standalone/zeromq_pubsub_topic_monitor.py b/wrapyfi/standalone/zeromq_pubsub_topic_monitor.py
index a20d966..b119f57 100644
--- a/wrapyfi/standalone/zeromq_pubsub_topic_monitor.py
+++ b/wrapyfi/standalone/zeromq_pubsub_topic_monitor.py
@@ -16,8 +16,8 @@ def monitor_active_connections(socket_pub_address, topic):
while True:
topic, message = subscriber.recv_multipart()
- topic = topic.decode('utf-8')
- data = json.loads(message.decode('utf-8'))
+ topic = topic.decode("utf-8")
+ data = json.loads(message.decode("utf-8"))
logging.info(f"[ZeroMQ] Topic: {topic}, Data: {data}")
except Exception as e:
@@ -26,15 +26,23 @@ def monitor_active_connections(socket_pub_address, topic):
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("--socket_pub_address", type=str, default="tcp://127.0.0.1:5555",
- help="Socket subscription address")
- parser.add_argument("--topic", type=str, default="ZEROMQ/CONNECTIONS",
- help="Topic to subscribe. The ZEROMQ/CONNECTIONS topic monitors subscribers to a any topic on "
- "the sub_address. Note that zeromq_proxy_broker.py in pubsub mode (either as a standalone"
- "or spawned by default by any publisher) must be running on socket_pub_address")
+ parser.add_argument(
+ "--socket_pub_address",
+ type=str,
+ default="tcp://127.0.0.1:5555",
+ help="Socket subscription address",
+ )
+ parser.add_argument(
+ "--topic",
+ type=str,
+ default="ZEROMQ/CONNECTIONS",
+ help="Topic to subscribe. The ZEROMQ/CONNECTIONS topic monitors subscribers to a any topic on "
+ "the sub_address. Note that zeromq_proxy_broker.py in pubsub mode (either as a standalone"
+ "or spawned by default by any publisher) must be running on socket_pub_address",
+ )
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
- monitor_active_connections(args.socket_pub_address, args.topic)
\ No newline at end of file
+ monitor_active_connections(args.socket_pub_address, args.topic)
diff --git a/wrapyfi/tests/test_middleware.py b/wrapyfi/tests/test_middleware.py
index 5398f31..7190de4 100644
--- a/wrapyfi/tests/test_middleware.py
+++ b/wrapyfi/tests/test_middleware.py
@@ -13,6 +13,7 @@ def test_publish_listen(self):
receive messages using the PUB/SUB pattern.
"""
import wrapyfi.tests.tools.class_test as class_test
+
# importlib.reload(wrapyfi)
# importlib.reload(class_test)
test_func = class_test.test_func
@@ -23,15 +24,29 @@ def test_publish_listen(self):
self.skipTest(f"{self.MWARE} not installed")
listen_queue = Queue(maxsize=10)
- test_lsn = multiprocessing.Process(target=test_func, args=(listen_queue,),
- kwargs={"mode": "listen", "mware": self.MWARE, "iterations": 10,
- "should_wait": True})
+ test_lsn = multiprocessing.Process(
+ target=test_func,
+ args=(listen_queue,),
+ kwargs={
+ "mode": "listen",
+ "mware": self.MWARE,
+ "iterations": 10,
+ "should_wait": True,
+ },
+ )
# test_lsn.daemon = True
publish_queue = Queue(maxsize=10)
- test_pub = multiprocessing.Process(target=test_func, args=(publish_queue,),
- kwargs={"mode": "publish", "mware": self.MWARE, "iterations": 10,
- "should_wait": True})
+ test_pub = multiprocessing.Process(
+ target=test_func,
+ args=(publish_queue,),
+ kwargs={
+ "mode": "publish",
+ "mware": self.MWARE,
+ "iterations": 10,
+ "should_wait": True,
+ },
+ )
# test_pub.daemon = True
test_lsn.start()
@@ -40,7 +55,9 @@ def test_publish_listen(self):
test_pub.join()
for i in range(10):
try:
- self.assertDictEqual(listen_queue.get(timeout=3), publish_queue.get(timeout=3))
+ self.assertDictEqual(
+ listen_queue.get(timeout=3), publish_queue.get(timeout=3)
+ )
except queue.Empty:
self.assertEqual(i, 9)
@@ -50,6 +67,7 @@ class ROS2TestMiddleware(ZeroMQTestMiddleware):
Test the ROS 2 wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class
are also run for the ROS 2 wrapper.
"""
+
MWARE = "ros2"
@@ -58,6 +76,7 @@ class YarpTestMiddleware(ZeroMQTestMiddleware):
Test the YARP wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class
are also run for the YARP wrapper.
"""
+
MWARE = "yarp"
@@ -66,8 +85,9 @@ class ROSTestMiddleware(ZeroMQTestMiddleware):
Test the ROS wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class
are also run for the ROS wrapper.
"""
+
MWARE = "ros"
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/wrapyfi/tests/test_wrapper.py b/wrapyfi/tests/test_wrapper.py
index dfd04cc..116f918 100644
--- a/wrapyfi/tests/test_wrapper.py
+++ b/wrapyfi/tests/test_wrapper.py
@@ -12,6 +12,7 @@ def test_activate_communication(self):
the connection should no longer be active. The ``activate_communication`` method should also be callable multiple times on the same function.
"""
import wrapyfi.tests.tools.class_test as class_test
+
Test = class_test.Test
Test.close_all_instances()
if self.MWARE not in Test.get_communicators():
@@ -21,7 +22,9 @@ def test_activate_communication(self):
test, test2, test3 = Test(), Test(), Test()
test.activate_communication(test.exchange_object, mode="publish")
test2.activate_communication(test2.exchange_object, mode=None)
- test2.activate_communication(test2.exchange_object, mode="publish") # override None
+ test2.activate_communication(
+ test2.exchange_object, mode="publish"
+ ) # override None
test3.activate_communication(test3.exchange_object, mode="disable")
# ensure both nodes are registered
@@ -29,9 +32,11 @@ def test_activate_communication(self):
self.assertIn("Test.exchange_object.2", Test._MiddlewareCommunicator__registry)
self.assertIn("Test.exchange_object.3", Test._MiddlewareCommunicator__registry)
for i in range(2):
- msg_object, = test.exchange_object(mware=self.MWARE)
- msg_object2, = test2.exchange_object(mware=self.MWARE, topic="/test/test_native_exchange2")
- msg_object3, = test3.exchange_object(mware=self.MWARE)
+ (msg_object,) = test.exchange_object(mware=self.MWARE)
+ (msg_object2,) = test2.exchange_object(
+ mware=self.MWARE, topic="/test/test_native_exchange2"
+ )
+ (msg_object3,) = test3.exchange_object(mware=self.MWARE)
self.assertDictEqual(msg_object, msg_object2)
self.assertIsNone(msg_object3)
test.close()
@@ -40,10 +45,14 @@ def test_activate_communication(self):
# ensure only one node is registered
self.assertIn("Test.exchange_object", Test._MiddlewareCommunicator__registry)
self.assertIn("Test.exchange_object.2", Test._MiddlewareCommunicator__registry)
- self.assertNotIn("Test.exchange_object.3", Test._MiddlewareCommunicator__registry)
+ self.assertNotIn(
+ "Test.exchange_object.3", Test._MiddlewareCommunicator__registry
+ )
for i in range(2):
- msg_object2, = test2.exchange_object(mware=self.MWARE, topic="/test/test_native_exchange2")
+ (msg_object2,) = test2.exchange_object(
+ mware=self.MWARE, topic="/test/test_native_exchange2"
+ )
test2.close()
test3.close()
del test2, test3
@@ -60,6 +69,7 @@ def test_close(self):
multiple times on different instances of a class.
"""
import wrapyfi.tests.tools.class_test as class_test
+
Test = class_test.Test
Test.close_all_instances()
@@ -82,11 +92,17 @@ def test_close(self):
del test2
self.assertIn("Test.exchange_object", Test._MiddlewareCommunicator__registry)
self.assertIn("Test.exchange_object.2", Test._MiddlewareCommunicator__registry)
- self.assertNotIn("Test.exchange_object.3", Test._MiddlewareCommunicator__registry)
+ self.assertNotIn(
+ "Test.exchange_object.3", Test._MiddlewareCommunicator__registry
+ )
# check that the first & third node are still registered
- test_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"]["__WRAPYFI_INSTANCES"].index(hex(id(test)))
- test3_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"]["__WRAPYFI_INSTANCES"].index(hex(id(test3)))
+ test_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"][
+ "__WRAPYFI_INSTANCES"
+ ].index(hex(id(test)))
+ test3_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"][
+ "__WRAPYFI_INSTANCES"
+ ].index(hex(id(test3)))
self.assertEqual(test_id, 0)
self.assertEqual(test3_id, 1)
@@ -100,8 +116,12 @@ def test_close(self):
self.assertIn("Test.exchange_object", Test._MiddlewareCommunicator__registry)
# check that the first & third node are still registered
- test_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"]["__WRAPYFI_INSTANCES"].index(hex(id(test)))
- test3_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"]["__WRAPYFI_INSTANCES"].index(hex(id(test3)))
+ test_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"][
+ "__WRAPYFI_INSTANCES"
+ ].index(hex(id(test)))
+ test3_id = Test._MiddlewareCommunicator__registry["Test.exchange_object"][
+ "__WRAPYFI_INSTANCES"
+ ].index(hex(id(test3)))
self.assertEqual(test_id, 1)
self.assertEqual(test3_id, 0)
@@ -112,14 +132,22 @@ def test_close(self):
del test
self.assertIn("Test.exchange_object", Test._MiddlewareCommunicator__registry)
- self.assertNotIn("Test.exchange_object.2", Test._MiddlewareCommunicator__registry)
- self.assertNotIn("Test.exchange_object.3", Test._MiddlewareCommunicator__registry)
+ self.assertNotIn(
+ "Test.exchange_object.2", Test._MiddlewareCommunicator__registry
+ )
+ self.assertNotIn(
+ "Test.exchange_object.3", Test._MiddlewareCommunicator__registry
+ )
# close all instances after all have been closed
Test.close_all_instances()
self.assertIn("Test.exchange_object", Test._MiddlewareCommunicator__registry)
- self.assertNotIn("Test.exchange_object.2", Test._MiddlewareCommunicator__registry)
- self.assertNotIn("Test.exchange_object.3", Test._MiddlewareCommunicator__registry)
+ self.assertNotIn(
+ "Test.exchange_object.2", Test._MiddlewareCommunicator__registry
+ )
+ self.assertNotIn(
+ "Test.exchange_object.3", Test._MiddlewareCommunicator__registry
+ )
def test_get_communicators(self):
"""
@@ -127,6 +155,7 @@ def test_get_communicators(self):
method should return a list of all available middleware wrappers.
"""
import wrapyfi.tests.tools.class_test as class_test
+
# importlib.reload(class_test)
Test = class_test.Test
@@ -144,6 +173,7 @@ class ROS2TestWrapper(ZeroMQTestWrapper):
Test the ROS 2 wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class
are also run for the ROS 2 wrapper.
"""
+
MWARE = "ros2"
@@ -152,6 +182,7 @@ class YarpTestWrapper(ZeroMQTestWrapper):
Test the YARP wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class
are also run for the YARP wrapper.
"""
+
MWARE = "yarp"
@@ -160,8 +191,9 @@ class ROSTestWrapper(ZeroMQTestWrapper):
Test the ROS wrapper. This test class inherits from the ZeroMQ test class, so all tests from the ZeroMQ test class
are also run for the ROS wrapper.
"""
+
MWARE = "ros"
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/wrapyfi/tests/tools/benchmarking_native_object.py b/wrapyfi/tests/tools/benchmarking_native_object.py
index 0bac873..138dd68 100644
--- a/wrapyfi/tests/tools/benchmarking_native_object.py
+++ b/wrapyfi/tests/tools/benchmarking_native_object.py
@@ -1,5 +1,6 @@
import argparse
import time
+
try:
import numpy as np
import pandas as pd
@@ -7,8 +8,9 @@
print("Install pandas and NumPy before running this script.")
try:
import tensorflow as tf
+
# avoid allocating all GPU memory assuming tf>=2.2
- gpus = tf.config.experimental.list_physical_devices('GPU')
+ gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
except ImportError:
@@ -28,11 +30,16 @@ def get_numpy_object(dims):
@staticmethod
def get_pandas_object(dims):
- return {"pandas": pd.DataFrame(np.ones(dims), index=None, columns=list(range(dims[-1])))}
+ return {
+ "pandas": pd.DataFrame(
+ np.ones(dims), index=None, columns=list(range(dims[-1]))
+ )
+ }
@staticmethod
def get_pillow_object(dims):
from PIL import Image
+
return {"pillow": Image.fromarray(np.ones(dims, dtype=np.uint8))}
@staticmethod
@@ -42,96 +49,178 @@ def get_tensorflow_object(dims):
@staticmethod
def get_jax_object(dims):
import jax as jx
+
return {"jax": jx.numpy.ones(dims)}
@staticmethod
def get_mxnet_object(dims):
import mxnet as mx
+
return {"mxnet": mx.nd.ones(dims)}
@staticmethod
def get_mxnet_gpu_object(dims, gpu=0):
import mxnet as mx
+
return {"mxnet_gpu": mx.nd.ones(dims, ctx=mx.gpu(gpu))}
@staticmethod
def get_pytorch_object(dims):
import torch as th
+
return {"pytorch": th.ones(dims)}
@staticmethod
def get_pytorch_gpu_object(dims, gpu=0):
import torch as th
+
return {"pytorch_gpu": th.ones(dims, device=f"cuda:{gpu}")}
@staticmethod
def get_paddle_object(dims):
import paddle as pa
+
return {"paddle": pa.Tensor(pa.ones(dims), place=pa.CPUPlace())}
@staticmethod
def get_paddle_gpu_object(dims, gpu=0):
import paddle as pa
+
return {"paddle_gpu": pa.Tensor(pa.zeros(dims), place=pa.CUDAPlace(gpu))}
def get_all_objects(self, count, plugin_name):
- obj = {"count": count,
- "time": time.time()}
- obj.update(**getattr(self, f"get_{plugin_name}_object")((args.height, args.width,)))
+ obj = {"count": count, "time": time.time()}
+ obj.update(
+ **getattr(self, f"get_{plugin_name}_object")(
+ (
+ args.height,
+ args.width,
+ )
+ )
+ )
return obj
- @MiddlewareCommunicator.register("NativeObject", "yarp",
- "ExampleClass", "/example/get_native_objects",
- carrier="tcp", should_wait=SHOULD_WAIT)
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "yarp",
+ "ExampleClass",
+ "/example/get_native_objects",
+ carrier="tcp",
+ should_wait=SHOULD_WAIT,
+ )
def get_yarp_native_objects(self, count, plugin_name):
- return self.get_all_objects(count, plugin_name),
+ return (self.get_all_objects(count, plugin_name),)
- @MiddlewareCommunicator.register("NativeObject", "ros",
- "ExampleClass", "/example/get_native_objects",
- carrier="tcp", should_wait=SHOULD_WAIT)
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "ros",
+ "ExampleClass",
+ "/example/get_native_objects",
+ carrier="tcp",
+ should_wait=SHOULD_WAIT,
+ )
def get_ros_native_objects(self, count, plugin_name):
- return self.get_all_objects(count, plugin_name),
+ return (self.get_all_objects(count, plugin_name),)
- @MiddlewareCommunicator.register("NativeObject", "ros2",
- "ExampleClass", "/example/get_native_objects",
- carrier="tcp", should_wait=SHOULD_WAIT)
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "ros2",
+ "ExampleClass",
+ "/example/get_native_objects",
+ carrier="tcp",
+ should_wait=SHOULD_WAIT,
+ )
def get_ros2_native_objects(self, count, plugin_name):
- return self.get_all_objects(count, plugin_name),
+ return (self.get_all_objects(count, plugin_name),)
- @MiddlewareCommunicator.register("NativeObject", "zeromq",
- "ExampleClass", "/example/get_native_objects",
- carrier="tcp", should_wait=SHOULD_WAIT)
+ @MiddlewareCommunicator.register(
+ "NativeObject",
+ "zeromq",
+ "ExampleClass",
+ "/example/get_native_objects",
+ carrier="tcp",
+ should_wait=SHOULD_WAIT,
+ )
def get_zeromq_native_objects(self, count, plugin_name):
- return self.get_all_objects(count, plugin_name),
+ return (self.get_all_objects(count, plugin_name),)
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument("--publish", dest="mode", action="store_const", const="publish", default="listen", help="Publish mode")
- parser.add_argument("--listen", dest="mode", action="store_const", const="listen", default="listen", help="Listen mode (default)")
- parser.add_argument("--mwares", type=str, default=list(MiddlewareCommunicator.get_communicators()),
- choices=MiddlewareCommunicator.get_communicators(), nargs="+",
- help="The middlewares to use for transmission")
- parser.add_argument("--plugins", type=str,
- default=["numpy", "pandas", "tensorflow", "jax", "mxnet", "mxnet_gpu", "pytorch", "pytorch_gpu",
- "paddle", "paddle_gpu"], nargs="+",
- help="The middlewares to use for transmission")
- parser.add_argument("--height", type=int, default=200, help="The tensor image height")
+ parser.add_argument(
+ "--publish",
+ dest="mode",
+ action="store_const",
+ const="publish",
+ default="listen",
+ help="Publish mode",
+ )
+ parser.add_argument(
+ "--listen",
+ dest="mode",
+ action="store_const",
+ const="listen",
+ default="listen",
+ help="Listen mode (default)",
+ )
+ parser.add_argument(
+ "--mwares",
+ type=str,
+ default=list(MiddlewareCommunicator.get_communicators()),
+ choices=MiddlewareCommunicator.get_communicators(),
+ nargs="+",
+ help="The middlewares to use for transmission",
+ )
+ parser.add_argument(
+ "--plugins",
+ type=str,
+ default=[
+ "numpy",
+ "pandas",
+ "tensorflow",
+ "jax",
+ "mxnet",
+ "mxnet_gpu",
+ "pytorch",
+ "pytorch_gpu",
+ "paddle",
+ "paddle_gpu",
+ ],
+ nargs="+",
+ help="The middlewares to use for transmission",
+ )
+ parser.add_argument(
+ "--height", type=int, default=200, help="The tensor image height"
+ )
parser.add_argument("--width", type=int, default=200, help="The tensor image width")
- parser.add_argument("--trials", type=int, default=2000, help="Number of trials to run per middleware")
- parser.add_argument("--skip-trials", type=int, default=0, help="Number of trials to skip before logging "
- "to csv to avoid warmup time logging")
+ parser.add_argument(
+ "--trials",
+ type=int,
+ default=2000,
+ help="Number of trials to run per middleware",
+ )
+ parser.add_argument(
+ "--skip-trials",
+ type=int,
+ default=0,
+ help="Number of trials to skip before logging "
+ "to csv to avoid warmup time logging",
+ )
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
benchmarker = Benchmarker()
- benchmark_logger = pd.DataFrame(columns=["middleware", "plugin", "timestamp", "count", "delay"])
+ benchmark_logger = pd.DataFrame(
+ columns=["middleware", "plugin", "timestamp", "count", "delay"]
+ )
benchmark_iterator = {}
for middleware_name in args.mwares:
- benchmark_iterator[middleware_name] = getattr(benchmarker, f"get_{middleware_name}_native_objects")
+ benchmark_iterator[middleware_name] = getattr(
+ benchmarker, f"get_{middleware_name}_native_objects"
+ )
for middleware_name, method in benchmark_iterator.items():
benchmarker.activate_communication(method, mode=args.mode)
@@ -140,20 +229,32 @@ def parse_args():
counter = -1
while True:
counter += 1
- native_objects, = method(counter, plugin_name)
+ (native_objects,) = method(counter, plugin_name)
if native_objects is not None:
time_acc_native_objects.append(time.time() - native_objects["time"])
- print(f"{middleware_name} :: {plugin_name} :: delay:", time_acc_native_objects[-1],
- " Length:", len(time_acc_native_objects), " Count:", native_objects["count"])
+ print(
+ f"{middleware_name} :: {plugin_name} :: delay:",
+ time_acc_native_objects[-1],
+ " Length:",
+ len(time_acc_native_objects),
+ " Count:",
+ native_objects["count"],
+ )
if args.trials - 1 == native_objects["count"]:
break
if counter > args.skip_trials:
- benchmark_logger = benchmark_logger.append(pd.DataFrame({"middleware": [middleware_name],
- "plugin": [plugin_name],
- "timestamp": [native_objects["timestamp"]],
- "count": [native_objects["count"]],
- "delay": [time_acc_native_objects[-1]]}),
- ignore_index=True)
+ benchmark_logger = benchmark_logger.append(
+ pd.DataFrame(
+ {
+ "middleware": [middleware_name],
+ "plugin": [plugin_name],
+ "timestamp": [native_objects["timestamp"]],
+ "count": [native_objects["count"]],
+ "delay": [time_acc_native_objects[-1]],
+ }
+ ),
+ ignore_index=True,
+ )
if counter == 0:
if args.mode == "publish":
time.sleep(5)
@@ -162,6 +263,12 @@ def parse_args():
time.sleep(0.1)
time_acc_native_objects = pd.DataFrame(np.array(time_acc_native_objects))
- print(f"{middleware_name} :: {plugin_name} :: time statistics:", time_acc_native_objects.describe())
+ print(
+ f"{middleware_name} :: {plugin_name} :: time statistics:",
+ time_acc_native_objects.describe(),
+ )
time.sleep(5)
- benchmark_logger.to_csv(f"results/benchmarking_native_object_{args.mode}__{','.join(args.mwares)}__{','.join(args.plugins)}.csv", index=False)
\ No newline at end of file
+ benchmark_logger.to_csv(
+ f"results/benchmarking_native_object_{args.mode}__{','.join(args.mwares)}__{','.join(args.plugins)}.csv",
+ index=False,
+ )
diff --git a/wrapyfi/tests/tools/class_test.py b/wrapyfi/tests/tools/class_test.py
index 5703e37..44de6e7 100644
--- a/wrapyfi/tests/tools/class_test.py
+++ b/wrapyfi/tests/tools/class_test.py
@@ -6,24 +6,48 @@
class Test(MiddlewareCommunicator):
- @MiddlewareCommunicator.register("NativeObject", "$mware", "Test", "$topic",
- should_wait="$should_wait")
- def exchange_object(self, msg=None, mware=DEFAULT_COMMUNICATOR, topic="/test/test_native_exchange", should_wait=False):
- ret = {"message": msg,
- "set": {'a', 1, None},
- "list": [[[3, [4], 5.677890, 1.2]]],
- "string": "string of characters",
- "2": 2.73211,
- "dict": {"other": [None, False, 16, 4.32,]}}
- return ret,
+ @MiddlewareCommunicator.register(
+ "NativeObject", "$mware", "Test", "$topic", should_wait="$should_wait"
+ )
+ def exchange_object(
+ self,
+ msg=None,
+ mware=DEFAULT_COMMUNICATOR,
+ topic="/test/test_native_exchange",
+ should_wait=False,
+ ):
+ ret = {
+ "message": msg,
+ "set": {"a", 1, None},
+ "list": [[[3, [4], 5.677890, 1.2]]],
+ "string": "string of characters",
+ "2": 2.73211,
+ "dict": {
+ "other": [
+ None,
+ False,
+ 16,
+ 4.32,
+ ]
+ },
+ }
+ return (ret,)
-def test_func(queue_buffer, mode="listen", mware=DEFAULT_COMMUNICATOR, topic="/test/test_native_exchange", iterations=2,
- should_wait=False):
+def test_func(
+ queue_buffer,
+ mode="listen",
+ mware=DEFAULT_COMMUNICATOR,
+ topic="/test/test_native_exchange",
+ iterations=2,
+ should_wait=False,
+):
test = Test()
test.activate_communication(test.exchange_object, mode=mode)
for i in range(iterations):
- my_message = test.exchange_object(msg=f"signal_idx:{i}", mware=mware, topic=topic, should_wait=should_wait)
+ my_message = test.exchange_object(
+ msg=f"signal_idx:{i}", mware=mware, topic=topic, should_wait=should_wait
+ )
if my_message is not None:
print(f"result {mode}:", my_message[0]["message"])
queue_buffer.put(my_message[0])
diff --git a/wrapyfi/utils.py b/wrapyfi/utils.py
index 95f27d0..6901413 100755
--- a/wrapyfi/utils.py
+++ b/wrapyfi/utils.py
@@ -12,7 +12,11 @@
lock = threading.Lock()
-def deepcopy(obj: Any, exclude_keys: Optional[Union[list, tuple]] = None, shallow_keys: Optional[Union[list, tuple]] = None):
+def deepcopy(
+ obj: Any,
+ exclude_keys: Optional[Union[list, tuple]] = None,
+ shallow_keys: Optional[Union[list, tuple]] = None,
+):
"""
Deep copy an object, excluding specified keys.
@@ -21,6 +25,7 @@ def deepcopy(obj: Any, exclude_keys: Optional[Union[list, tuple]] = None, shallo
:param shallow_keys: Union[list, tuple]: A list of keys to shallow copy
"""
import copy
+
if exclude_keys is None:
return copy.deepcopy(obj)
else:
@@ -32,7 +37,11 @@ def deepcopy(obj: Any, exclude_keys: Optional[Union[list, tuple]] = None, shallo
return {deepcopy(item, exclude_keys) for item in obj}
elif isinstance(obj, dict):
_shallows = shallow_keys or []
- ret = {key: deepcopy(val, exclude_keys) for key, val in obj.items() if key not in exclude_keys + _shallows}
+ ret = {
+ key: deepcopy(val, exclude_keys)
+ for key, val in obj.items()
+ if key not in exclude_keys + _shallows
+ }
ret.update({key: val for key, val in obj.items() if key in _shallows})
return ret
else:
@@ -46,6 +55,7 @@ def get_default_args(fnc: Callable[..., Any]):
:param fnc: Callable[..., Any]: The function to get the default arguments for
"""
import inspect
+
signature = inspect.signature(fnc)
return {
k: v.default
@@ -54,7 +64,12 @@ def get_default_args(fnc: Callable[..., Any]):
}
-def match_args(args: Union[list, tuple], kwargs: dict, src_args: Union[list, tuple], src_kwargs: dict):
+def match_args(
+ args: Union[list, tuple],
+ kwargs: dict,
+ src_args: Union[list, tuple],
+ src_kwargs: dict,
+):
"""
Match and Substitute Arguments and Keyword Arguments using Specified Source Values.
@@ -91,7 +106,11 @@ def match_args(args: Union[list, tuple], kwargs: dict, src_args: Union[list, tup
for kwarg_key, kwarg_val in kwargs.items():
if isinstance(kwarg_val, str) and "$" in kwarg_val and kwarg_val[1:].isdigit():
new_kwargs[kwarg_key] = src_args[int(kwarg_val[1:])]
- elif isinstance(kwarg_val, str) and "$" in kwarg_val and not kwarg_val[1:].isdigit():
+ elif (
+ isinstance(kwarg_val, str)
+ and "$" in kwarg_val
+ and not kwarg_val[1:].isdigit()
+ ):
new_kwargs[kwarg_key] = src_kwargs[kwarg_val[1:]]
else:
new_kwargs[kwarg_key] = kwarg_val
@@ -111,14 +130,14 @@ def dynamic_module_import(modules: List[str], globals: dict):
module_name = module_name[:-3]
module_name = module_name.replace("/", ".")
try:
- module = __import__(module_name, fromlist=['*'])
+ module = __import__(module_name, fromlist=["*"])
except ImportError as e:
# print(module_name + " could not be imported.", e)
continue
- if hasattr(module, '__all__'):
+ if hasattr(module, "__all__"):
all_names = module.__all__
else:
- all_names = [name for name in dir(module) if not name.startswith('_')]
+ all_names = [name for name in dir(module) if not name.startswith("_")]
globals.update({name: getattr(module, name) for name in all_names})
@@ -128,13 +147,16 @@ class SingletonOptimized(type):
Source: https://stackoverflow.com/a/6798042
"""
+
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
with lock:
if cls not in cls._instances:
- cls._instances[cls] = super(SingletonOptimized, cls).__call__(*args, **kwargs)
+ cls._instances[cls] = super(SingletonOptimized, cls).__call__(
+ *args, **kwargs
+ )
return cls._instances[cls]
@@ -142,6 +164,7 @@ class Plugin(object):
"""
Base class for encoding and decoding plugins.
"""
+
def encode(self, *args, **kwargs):
"""
Encode data into a base64 string.
@@ -172,6 +195,7 @@ class PluginRegistrar(object):
"""
Class for registering encoding and decoding plugins.
"""
+
encoder_registry = {}
decoder_registry = {}
@@ -182,12 +206,14 @@ def register(types=None):
:param types: tuple: The type(s) to register the plugin for
"""
+
def wrapper(cls):
if types is not None:
for cls_type in types:
PluginRegistrar.encoder_registry[cls_type] = cls
PluginRegistrar.decoder_registry[str(cls.__name__)] = cls
return cls
+
return wrapper
@staticmethod
@@ -196,20 +222,30 @@ def scan():
Scan the plugins directory (Wrapyfi builtin and external) for plugins to register.
This method is called automatically when the module is imported.
"""
- modules = glob(os.path.join(os.path.dirname(__file__), "plugins", "*.py"), recursive=True)
- modules = ["wrapyfi.plugins." + module.replace(os.path.dirname(__file__) + "/plugins/", "") for module in modules]
+ modules = glob(
+ os.path.join(os.path.dirname(__file__), "plugins", "*.py"), recursive=True
+ )
+ modules = [
+ "wrapyfi.plugins."
+ + module.replace(os.path.dirname(__file__) + "/plugins/", "")
+ for module in modules
+ ]
dynamic_module_import(modules, globals())
extern_modules_paths = os.environ.get(WRAPYFI_PLUGIN_PATHS, "").split(":")
for mod_group_idx, extern_module_path in enumerate(extern_modules_paths):
- extern_modules = glob(os.path.join(extern_module_path, "plugins", "*.py"), recursive=True)
+ extern_modules = glob(
+ os.path.join(extern_module_path, "plugins", "*.py"), recursive=True
+ )
for mod_idx, extern_module in enumerate(extern_modules):
- spec = importlib.util.spec_from_file_location(f"wrapyfi.extern{mod_group_idx}.plugins{mod_idx}", extern_module)
+ spec = importlib.util.spec_from_file_location(
+ f"wrapyfi.extern{mod_group_idx}.plugins{mod_idx}", extern_module
+ )
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
- extern_modules = [f"wrapyfi.extern{mod_group_idx}.plugins{mod_idx}." + extern_module.replace(
- extern_module_path + "/plugins/", "")]
+ extern_modules = [
+ f"wrapyfi.extern{mod_group_idx}.plugins{mod_idx}."
+ + extern_module.replace(extern_module_path + "/plugins/", "")
+ ]
dynamic_module_import(extern_modules, globals())
-
-