Skip to content

Commit

Permalink
Merge pull request #100 from modular-ml/dev
Browse files Browse the repository at this point in the history
Added Trax support and minor fixes to JAX
  • Loading branch information
fabawi committed Feb 21, 2024
2 parents ed755e8 + 410a65f commit e77475e
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 3 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ For more examples of usage, refer to the [user guide](docs/usage.md). Run script
- [ ] **msgpack**
- [ ] **protobuf**


## Data Structures

Supported Objects by the `NativeObject` type include:
Expand All @@ -349,6 +350,7 @@ Supported Objects by the `NativeObject` type include:
- [x] [**PyTorch Tensor**](https://pytorch.org/docs/stable/index.html)
- [x] [**TensorFlow 2 Tensor**](https://www.tensorflow.org/api_docs/python/tf)
- [x] [**JAX Tensor**](https://jax.readthedocs.io/en/latest/)
- [x] [**Trax Array**](https://trax-ml.readthedocs.io/en/latest/)
- [x] [**MXNet Tensor**](https://mxnet.apache.org/versions/1.9.1/api/python.html)
- [x] [**PaddlePaddle Tensor**](https://www.paddlepaddle.org.cn/documentation/docs/en/guides/index_en.html)
- [x] [**pandas DataFrame | Series**](https://pandas.pydata.org/docs/)
Expand All @@ -363,12 +365,14 @@ Supported Objects by the `NativeObject` type include:
- [ ] [**Gmpy 2 MPZ**](https://gmpy2.readthedocs.io/en/latest/) [![planned](https://custom-icon-badges.demolab.com/badge/planned%20for%20Wrapyfi%20v0.5-%23C2E0C6.svg?logo=hourglass&logoColor=white)](https://github.com/modular-ml/wrapyfi/issues/99 "planned link")
- [ ] [**MLX Tensor**](https://ml-explore.github.io/mlx/build/html/index.html) [![planned](https://custom-icon-badges.demolab.com/badge/planned%20for%20Wrapyfi%20v0.5-%23C2E0C6.svg?logo=hourglass&logoColor=white)](https://github.com/modular-ml/wrapyfi/issues/99 "planned link")


## Image

Supported Objects by the `Image` type include:

- [x] **NumPy Array** [*supports many libraries including [scikit-image](https://scikit-image.org/), [imageio](https://imageio.readthedocs.io/en/stable/), [Open CV](https://opencv.org/), [imutils](https://github.com/PyImageSearch/imutils), [matplotlib.image](https://matplotlib.org/stable/api/image_api.html), and [Mahotas](https://mahotas.readthedocs.io/en/latest/)*]


## Sound

Supported Objects by the `AudioChunk` type include:
Expand Down
5 changes: 4 additions & 1 deletion docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ A message publisher and listener for native Python objects and CuPy arrays.
A message publisher and listener for native Python objects and Zarr arrays/groups.

## [JAX](https://wrapyfi.readthedocs.io/en/latest/examples/examples.encoders.html#module-examples.encoders.jax_example)
A message publisher and listener for native Python objects and JAX arrays.
A message publisher and listener for native Python objects and JAX tensors.

## [Trax](https://wrapyfi.readthedocs.io/en/latest/examples/examples.encoders.html#module-examples.encoders.trax_example)
A message publisher and listener for native Python objects and Trax arrays.

## [MXNet](https://wrapyfi.readthedocs.io/en/latest/examples/examples.encoders.html#module-examples.encoders.mxnet_example)
A message publisher and listener for native Python objects and MXNet tensors.
Expand Down
2 changes: 1 addition & 1 deletion docs/exclude_packages.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
["torch", "tensorflow", "collections", "std_msgs.msg", "std_msgs.msg.String", "std_msgs.msg.Image", "dask.array", "dask.dataframe", "numpy", "cupy.ndarray", "wrapyfi.encoders", "wrapyfi.connect.wrapper", "wrapyfi.connect.listeners", "PIL", "importlib", "scipy.stats", "rostopic", "Cython.Compiler.UtilNodes", "jax.numpy", "jax.numpy.DeviceArray", "jax.Array", "rospy", "wrapyfi.connect.clients", "xarray.DataArray", "xarray.Dataset", "wrapyfi_ros2_interfaces.srv", "zipfile", "wrapyfi.config.manager", "Utils", "zipimport", "rclpy.node", "cv2", "wrapyfi.tests.tools.class_test", "wrapyfi.connect.servers", "pandas.DataFrame", "pandas.Series", "wrapyfi.utils", "yarp", "rclpy", "astropy.table", "pint", "mxnet", "wrapyfi.connect.publishers", "pexpect", "zarr.Array", "zarr.Group", "sensor_msgs.msg", "gzip", "zmq", "pyarrow.StructArray", "paddle", "sounddevice", "traceback", "geometry_msgs.msg", "astropy", "tempfile"]
["torch", "tensorflow", "collections", "std_msgs.msg", "std_msgs.msg.String", "std_msgs.msg.Image", "dask.array", "dask.dataframe", "numpy", "cupy.ndarray", "wrapyfi.encoders", "wrapyfi.connect.wrapper", "wrapyfi.connect.listeners", "PIL", "importlib", "scipy.stats", "rostopic", "Cython.Compiler.UtilNodes", "jax.numpy", "jax.numpy.DeviceArray", "jax.Array", "jaxlib.xla_extension.ArrayImpl", "trax", "trax.fastmath", "trax.fastmath.numpy", "trax.fastmath", "rospy", "wrapyfi.connect.clients", "xarray.DataArray", "xarray.Dataset", "wrapyfi_ros2_interfaces.srv", "zipfile", "wrapyfi.config.manager", "Utils", "zipimport", "rclpy.node", "cv2", "wrapyfi.tests.tools.class_test", "wrapyfi.connect.servers", "pandas.DataFrame", "pandas.Series", "wrapyfi.utils", "yarp", "rclpy", "astropy.table", "pint", "mxnet", "wrapyfi.connect.publishers", "pexpect", "zarr.Array", "zarr.Group", "sensor_msgs.msg", "gzip", "zmq", "pyarrow.StructArray", "paddle", "sounddevice", "traceback", "geometry_msgs.msg", "astropy", "tempfile"]
16 changes: 16 additions & 0 deletions docs/source/wrapyfi.standalone.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@ wrapyfi.standalone package
Submodules
----------

wrapyfi.standalone.mqtt\_publisher module
-----------------------------------------

.. automodule:: wrapyfi.standalone.mqtt_publisher
:members:
:undoc-members:
:show-inheritance:

wrapyfi.standalone.mqtt\_subscriber module
------------------------------------------

.. automodule:: wrapyfi.standalone.mqtt_subscriber
:members:
:undoc-members:
:show-inheritance:

wrapyfi.standalone.zeromq\_param\_server module
-----------------------------------------------

Expand Down
1 change: 1 addition & 0 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ Other than native python objects, the following objects are supported:
* `tensorflow.Tensor` and `tensorflow.EagerTensor`
* `mxnet.nd.NDArray`
* `jax.numpy.DeviceArray`
* `trax.ArrayImpl` -> `jaxlib.xla_extension.ArrayImpl`
* `paddle.Tensor`
* `PIL.Image`
* `pyarrow.StructArray`
Expand Down
1 change: 1 addition & 0 deletions docs/usage/User Guide/Plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Other than native python objects, the following objects are supported:
* `tensorflow.Tensor` and `tensorflow.EagerTensor`
* `mxnet.nd.NDArray`
* `jax.numpy.DeviceArray`
* `trax.ArrayImpl` -> `jaxlib.xla_extension.ArrayImpl`
* `paddle.Tensor`
* `PIL.Image`
* `pyarrow.StructArray`
Expand Down
102 changes: 102 additions & 0 deletions examples/encoders/trax_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
A message publisher and listener for native Python objects and Trax arrays.
This script demonstrates the capability to transmit native Python objects and Trax arrays using
the MiddlewareCommunicator within the Wrapyfi library. The communication follows the PUB/SUB pattern
allowing message publishing and listening functionalities between processes or machines.
Demonstrations:
- Using the NativeObject message
- Transmitting a nested dummy Python object with native objects and Trax arrays
- Applying the PUB/SUB pattern with mirroring
Requirements:
- Wrapyfi: Middleware communication wrapper (refer to the Wrapyfi documentation for installation instructions)
- YARP, ROS, ROS 2, ZeroMQ (refer to the Wrapyfi documentation for installation instructions)
- Trax: Used for handling Google Trax arrays (refer to https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html for installation instructions)
Install using pip:
``pip install trax``
Run:
# On machine 1 (or process 1): Publisher waits for keyboard input and transmits message
``python3 trax_example.py --mode publish``
# On machine 2 (or process 2): Listener waits for message and prints the entire dummy object
``python3 trax_example.py --mode listen``
"""

import argparse

try:
import trax
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('tensorflow-numpy')
except ImportError:
print("Install Trax before running this script.")

from wrapyfi.connect.wrapper import MiddlewareCommunicator, DEFAULT_COMMUNICATOR


class Notifier(MiddlewareCommunicator):
@MiddlewareCommunicator.register(
"NativeObject",
"$mware",
"Notifier",
"/notify/test_trax_exchange",
carrier="tcp",
should_wait=True,
)
def exchange_object(self, mware=None):
"""
Exchange messages with Trax arrays and other native Python objects.
"""
msg = input("Type your message: ")
ret = {
"message": msg,
"trax_array": fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
"trax_ones": fastnp.ones(3),
"trax_zeros": fastnp.zeros(3),
}
return (ret,)


def parse_args():
"""
Parse command line arguments.
"""
parser = argparse.ArgumentParser(
description="A message publisher and listener for native Python objects and Trax arrays."
)
parser.add_argument(
"--mode",
type=str,
default="publish",
choices={"publish", "listen"},
help="The transmission mode",
)
parser.add_argument(
"--mware",
type=str,
default=DEFAULT_COMMUNICATOR,
choices=MiddlewareCommunicator.get_communicators(),
help="The middleware to use for transmission",
)
return parser.parse_args()


def main(args):
"""
Main function to initiate Notifier class and communication.
"""
notifier = Notifier()
notifier.activate_communication(Notifier.exchange_object, mode=args.mode)

while True:
(msg_object,) = notifier.exchange_object(mware=args.mware)
print("Method result:", msg_object)


if __name__ == "__main__":
args = parse_args()
main(args)
2 changes: 1 addition & 1 deletion wrapyfi/plugins/jax_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

try:
# if jax 0.3.22 is installed, then jax.numpy is a module
types = None if not HAVE_JAX else jax.numpy.DeviceArray.__mro__[:-1] + (jax.Array,)
types = None if not HAVE_JAX else (jax.Array,)
except AttributeError:
types = None if not HAVE_JAX else jax.numpy.DeviceArray.__mro__[:-1]

Expand Down
89 changes: 89 additions & 0 deletions wrapyfi/plugins/trax_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
Encoder and Decoder for Trax Array Data via Wrapyfi.
This script provides mechanisms to encode and decode Trax Array Data using Wrapyfi.
It utilizes base64 encoding and pickle serialization.
The script contains a class, `TraxArray`, registered as a plugin to manage the
conversion of Trax Array data (if available) between its original and encoded forms.
Requirements:
- Wrapyfi: Middleware communication wrapper (refer to the Wrapyfi documentation for installation instructions)
- Trax: A deep learning library that focuses on clear code and speed (refer to https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html for installation instructions)
Note: If Trax is not available, HAVE_TRAX will be set to False and
the plugin will be registered with no types. Trax uses JAX or TensorFlow-NumPy as its backend,
so they must be installed as well. Trax installs JAX as a dependency, but TensorFlow must be installed separately.
You can install the necessary packages using pip:
``pip install trax`` # Basic installation of Trax
"""

import pickle
import base64

from wrapyfi.utils import *

try:
import trax
import jax
import jaxlib.xla_extension
HAVE_TRAX = True
except ImportError:
HAVE_TRAX = False


@PluginRegistrar.register(
types=None if not HAVE_TRAX else jaxlib.xla_extension.ArrayImpl.__mro__[:-1]
)
class TraxArray(Plugin):
def __init__(self, **kwargs):
"""
Initialize the TraxArray plugin.
"""
pass

def encode(self, obj, *args, **kwargs):
"""
Encode Trax Array data using pickle and base64.
:param obj: jaxlib.xla_extension.ArrayImpl: The Trax Array data to encode
:param args: tuple: Additional arguments (not used)
:param kwargs: dict: Additional keyword arguments (not used)
:return: Tuple[bool, dict]: A tuple containing:
- bool: Always True, indicating that the encoding was successful
- dict: A dictionary containing:
- '__wrapyfi__': A tuple containing the class name, pickled data string, and any buffer data
"""
buffers = []
obj_data = pickle.dumps(obj, protocol=5, buffer_callback=buffers.append)
obj_buffers = list(
map(lambda x: base64.b64encode(memoryview(x)).decode("ascii"), buffers)
)
return True, dict(
__wrapyfi__=(
str(self.__class__.__name__),
obj_data.decode("latin1"),
*obj_buffers,
)
)

def decode(self, obj_type, obj_full, *args, **kwargs):
"""
Decode a pickled and base64 encoded string back into Trax Array data.
:param obj_type: type: The expected type of the decoded object (not used)
:param obj_full: tuple: A tuple containing the pickled data string and any buffer data
:param args: tuple: Additional arguments (not used)
:param kwargs: dict: Additional keyword arguments (not used)
:return: Tuple[bool, pa.StructArray]: A tuple containing:
- bool: Always True, indicating that the decoding was successful
- jaxlib.xla_extension.ArrayImpl: The decoded Trax Array data
"""
obj_data = obj_full[1].encode("latin1")
obj_buffers = list(
map(lambda x: base64.b64decode(x.encode("ascii")), obj_full[2:])
)
obj_data = bytearray(obj_data)
for buf in obj_buffers:
obj_data += buf
return True, pickle.loads(obj_data, buffers=obj_buffers)

0 comments on commit e77475e

Please sign in to comment.