Skip to content

Commit

Permalink
fix (JAX backend)(manipulation.py): adding a temp fix to workaround t…
Browse files Browse the repository at this point in the history
…he issue wherein `jnp.broadcast_to` fails when passing an instance of `nnx.Variable`
  • Loading branch information
YushaArif99 committed Sep 24, 2024
1 parent 240cd14 commit 667e054
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ivy/functional/backends/jax/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,11 @@ def expand(
for i, dim in enumerate(shape):
if dim < 0:
shape[i] = x.shape[i]

#TODO: remove this once JAX supports passing in nnx.Variables to
# jnp.broadcast_to
if hasattr(x, '__jax_array__'):
x = x.__jax_array__()
return jnp.broadcast_to(x, tuple(shape))


Expand Down

0 comments on commit 667e054

Please sign in to comment.