Deal with static but ragged arrays in JAX #19889

Answered by jakevdp
renecotyfanboy asked this question in Q&A

Hi - I think there are two general possibilities here. First, you could store your ragged arrays in a Python list, where each element is a JAX array of different size. You wouldn't be able to vmap or scan over the list, but you could iterate using e.g. Python list comprehensions and XLA will generate efficient code for the resulting array operations. The only downside here is that these Python loops will be unrolled before being passed to the compiler, so if your list has many entries it could lead to very slow compilation.
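Here is a minimal sketch of the list-of-arrays approach. It assumes the goal is some per-row reduction; the sum of squares is a hypothetical stand-in for the real computation. The list passes through jax.jit as an ordinary pytree, and the comprehension is unrolled at trace time into one operation per element.

```python
import jax
import jax.numpy as jnp

# Ragged data kept as a plain Python list of JAX arrays with different lengths.
ragged = [jnp.arange(3.0), jnp.arange(5.0), jnp.arange(2.0)]

@jax.jit
def per_row_sums(rows):
    # The comprehension executes during tracing, so it is unrolled:
    # XLA receives one reduction per list element, each with a static shape.
    # (Sum of squares is only an illustrative per-row computation.)
    return jnp.stack([jnp.sum(r ** 2) for r in rows])

print(per_row_sums(ragged))  # per-row sums of squares: 5.0, 30.0, 1.0
```

jit specializes on the list's structure and on each element's shape. Changing the number of rows or any row length therefore triggers a new trace and compile. A long list also means a long unrolled program, which is the compile-time cost mentioned above.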

An alternative would be to design your own concatenated storage format for the ragged data; say, one array data holding the concatenated results and another array start_i…
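The answer is truncated at this point, so the exact format it describes may differ from the one below. This is one possible concatenated layout, sketched under stated assumptions: a flat data array plus per-element segment ids, reduced with jax.ops.segment_sum. The names offsets, lengths, segment_ids and get_row are hypothetical, chosen only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical ragged input: three rows of lengths 3, 5 and 2.
rows = [np.arange(3.0), np.arange(5.0), np.arange(2.0)]

# Concatenated storage (illustrative names):
#   data        - all rows laid end to end in one flat array
#   offsets     - start position of each row inside `data`
#   segment_ids - for every element of `data`, the index of its row
lengths = np.array([len(r) for r in rows])
offsets = np.concatenate([[0], np.cumsum(lengths)[:-1]])
data = jnp.asarray(np.concatenate(rows))
segment_ids = jnp.asarray(np.repeat(np.arange(len(rows)), lengths))
num_rows = len(rows)

@partial(jax.jit, static_argnames="num_segments")
def per_row_sums(data, segment_ids, num_segments):
    # One fixed-shape kernel over the flat array; segment_sum groups the
    # elements by row id. num_segments must be static to fix the output shape.
    return jax.ops.segment_sum(data ** 2, segment_ids, num_segments=num_segments)

def get_row(i):
    # Recover row i by slicing with concrete NumPy offsets (outside of jit).
    start = int(offsets[i])
    return data[start:start + int(lengths[i])]

print(per_row_sums(data, segment_ids, num_rows))  # 5.0, 30.0, 1.0
```

The shapes jit sees depend only on the total element count and the number of rows, not on how the elements are split between rows. This avoids both the unrolled Python loop and a recompile for each new row layout.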

Replies: 1 comment 4 replies

4 replies
@renecotyfanboy
@cottrell
@renecotyfanboy
@cottrell

Answer selected by renecotyfanboy
Category
Q&A
3 participants