Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Jax support #4

Merged
merged 10 commits into from
Sep 24, 2024
5 changes: 5 additions & 0 deletions .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ jobs:
run: |
pip install --upgrade pip
pip install build
- name: Patch the README links to point to the correct files at the current tag
run: |
perl -i.bak -pe 's/\[(.*)\]\((?!http)\.?\/?(.*)\.([a-z]+)\)/[\1](https:\/\/raw.github.com\/idiap\/RawSpeechClassification\/${{ github.event.release.tag_name }}\/\2.\3)/g' README.md
perl -i.bak -pe 's/\[(.*)\]\((?!http)\.?\/?(.*)\)/[\1](https:\/\/github.com\/idiap\/RawSpeechClassification\/tree\/${{ github.event.release.tag_name }}\/\2)/g' README.md
rm README.md.bak
- name: Package the project
run: python -m build
- name: Produce a GitHub actions artifact (the package)
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ SPDX-License-Identifier: GPL-3.0-only

# Changelog

## September 2024

- Add Jax backend
- Make pip installable package on PyPi

## August 2024

- Update the code for Keras 3 with PyTorch or Tensorflow backend
Expand Down
31 changes: 26 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ SPDX-License-Identifier: GPL-3.0-only

# Raw Speech Classification

[![PyPI package](https://shields.io/pypi/v/raw-speech-classification.svg?logo=pypi)](https:/pypi.org/project/raw-speech-classification)

Trains CNN (or any neural network based) classifiers from raw speech using Keras and
tests them. The inputs are lists of wav files, where each file is labelled. It then
creates fixed length signals and processes them. During testing, it computes scores at
Expand All @@ -14,7 +16,7 @@ the fixed length signals.

## Installation

### From source, In a Conda environment
### From source in a conda environment

To install Keras 3 with PyTorch backend, run:

Expand All @@ -28,6 +30,12 @@ To install Keras 3 with TensorFlow backend, run:
conda env create -f conda/rsclf-tensorflow.yaml
```

To install Keras 3 with Jax backend, run:

```bash
conda env create -f conda/rsclf-jax.yaml
```

Then install the package in that environment (the default name is `rsclf`) with:

```bash
Expand All @@ -49,15 +57,28 @@ or
pip install raw-speech-classification[tensorflow]
```

You'll also need to set the `KERAS_BACKEND` environment variable to the correct backend
or

```bash
pip install raw-speech-classification[jax]
```

If you already have an environment with PyTorch, TensorFlow, or Jax
installed, you can simply run:

```bash
pip install raw-speech-classification
```

You will also need to set the `KERAS_BACKEND` environment variable to the correct backend
before running `rsclf-train` or `rsclf-test` (see below), or globally for the current
bash session with:

```bash
export KERAS_BACKEND=torch
```

Replace `torch` by `tensorflow` accordingly.
Replace `torch` by `tensorflow` or `jax` accordingly.

## Using the code

Expand All @@ -68,7 +89,7 @@ Replace `torch` by `tensorflow` accordingly.
`root` option could be `/home/bob/data/my_dataset` and the content of the files would
then be like:

```txt
```text
part1/file1.wav 1
part1/file2.wav 0
```
Expand Down Expand Up @@ -127,7 +148,7 @@ obtain the following curve in `results/seg-f1/plot.png`:
probabilities. If you need the results per speaker, configure it accordingly (see the
script for details). The default output format is:

```txt
```text
<speakerID> <label> [<posterior_probability_vector>]
```

Expand Down
16 changes: 16 additions & 0 deletions conda/rsclf-jax.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright © Idiap Research Institute <[email protected]>
#
# SPDX-License-Identifier: GPL-3.0-only

name: rsclf
dependencies:
- python=3.11
- pip
- pip:
- keras
- h5py
- scipy
- jax[cuda12]
- matplotlib
- numpy
- polars
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ torch = [
tensorflow = [
"tensorflow[and-cuda]",
]
jax = [
"jax[cuda12]",
]
dev = [
"pre-commit",
]
Expand Down