-
Notifications
You must be signed in to change notification settings - Fork 909
/
example.py
39 lines (26 loc) · 1.03 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Copyright © 2024 Apple Inc.
import mlx.core as mx
from utils import load_audio, save_audio
from encodec import EncodecModel
# Pretrained 48 kHz EnCodec checkpoint and its matching preprocessor.
MODEL_REPO = "mlx-community/encodec-48khz-float32"
model, processor = EncodecModel.from_pretrained(MODEL_REPO)

# Input audio file. Replace this placeholder with a real path before running.
# The model's sampling rate and channel count are passed to the loader.
AUDIO_PATH = "/path/to/audio"
audio = load_audio(AUDIO_PATH, model.sampling_rate, model.channels)

# Turn the raw audio into model features and a mask. The processor also
# accepts a list of arrays for batched processing.
feats, mask = processor(audio)
# Compress features at a fixed target bandwidth of 3 (presumably kbps; confirm
# against the EncodecModel docs). A lower bandwidth compresses more but
# reconstructs the audio less faithfully.
def encode(feats, mask):
    """Encode ``feats``/``mask`` with the module-level ``model`` at bandwidth 3.

    Returns whatever ``model.encode`` produces; the script unpacks it as a
    ``(codes, scales)`` pair.
    """
    return model.encode(feats, mask, bandwidth=3)


# Equivalent to decorating with @mx.compile: wrap the function so MLX can
# compile its computation graph.
encode = mx.compile(encode)
# Map codes and scales back to audio using the module-level ``model``.
def decode(codes, scales, mask):
    """Reconstruct audio from ``codes``/``scales`` (and ``mask``) via ``model.decode``."""
    return model.decode(codes, scales, mask)


# Equivalent to decorating with @mx.compile.
decode = mx.compile(decode)
# Round trip: compress to discrete codes, then decode them back to audio.
codes, scales = encode(feats, mask)
num_samples = len(audio)
# Keep the first batch entry and drop any padding past the input's length.
reconstructed = decode(codes, scales, mask)[0, :num_samples]

# Write the result to a WAV file at the model's sampling rate.
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)