Deal with static but ragged arrays in JAX #19889
Hi, this is a fairly specific question about a use case that I can't solve on my own. The idea is simple: I want to work with a list of arrays of different sizes (tabulated fluxes associated with a list of temperatures). The function interpolates these tables in two stages: I select the two vectors that bracket the desired temperature, interpolate each of them to reduce it to a single value, and then combine the two values. The only issue is that each of these vectors has a different length (due to inhomogeneous sampling). They are static, however, and no intermediate vector is produced with a shape which is unknown at runtime. I have tried to pad everything to the same size, but this is not achievable in this specific use case, so I am trying to think of a workaround that uses vmap/scan or an equivalent to get fast evaluation. Any thoughts on how to approach this?
Replies: 1 comment 4 replies
Hi - I think there are a few general possibilities here.

First, you could store your ragged arrays in a Python list, where each element is a JAX array of a different size. You wouldn't be able to `vmap` or `scan` over the list, but you could iterate using e.g. Python list comprehensions, and XLA will generate efficient code for the resulting array operations. The only downside here is that these Python loops will be unrolled before being passed to the compiler, so if your list has many entries it could lead to very slow compilation.

An alternative would be to design your own concatenated storage format for the ragged data; say one array `data` with the concatenated results, and another array start_i…

Finally, a third approach might be to pad your ragged arrays so they fit into a two-dimensional array, with the lengths stored in a companion array. There are tradeoffs here with respect to a concatenated approach in terms of memory and computation, but the benefit is that it would allow you to use …

As you can see, unfortunately there's no general one-size-fits-all approach here, but one of those three options should probably unblock you. What do you think?
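For the first option (a Python list iterated with a comprehension), a minimal sketch might look like the following. It assumes linear interpolation via `jnp.interp`; the toy tables and the names `temps`, `xs`, `fluxes`, and `flux_from_list` are hypothetical and only for illustration. Note that it interpolates every table rather than only the two that bracket the temperature, which is a simplification that keeps all shapes static under `jit`.

```python
import jax
import jax.numpy as jnp

# Hypothetical ragged tables: one (x, flux) pair per temperature,
# each sampled on a grid of a different length.
temps = jnp.array([100.0, 200.0, 300.0])
xs = [jnp.linspace(0.0, 1.0, n) for n in (4, 7, 5)]
fluxes = [t * x for t, x in zip((100.0, 200.0, 300.0), xs)]  # toy data: flux = T * x


@jax.jit
def flux_from_list(temp, x):
    # The comprehension is unrolled at trace time: one jnp.interp call per table.
    per_table = jnp.stack([jnp.interp(x, xi, fi) for xi, fi in zip(xs, fluxes)])
    # per_table has the static shape (len(temps),); interpolate linearly in temperature.
    return jnp.interp(temp, temps, per_table)


print(flux_from_list(250.0, 0.3))  # close to 250 * 0.3 = 75 for this toy data
```

Because the loop is unrolled, compile time grows with the number of tables, which is the tradeoff mentioned above; for a handful of tables this is usually not a concern.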