I'm not sure how to use `jnp.pad` with a variable pad width in a traced context; I'm running into this error: `TypeError: iteration over a 0-d array`.
Short reproduction notebook here. Thanks!
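The notebook itself isn't included here, so as a rough stand-in, here is a minimal sketch of the failing pattern, assuming the pad width is passed to the jitted function as an ordinary (traced) argument. The name `pad_traced` is hypothetical, and the exact exception class and message depend on the JAX version (newer releases report a concretization error, which is still a `TypeError` subclass).

```python
import jax
from jax import numpy as jnp

img = jnp.arange(25).reshape((5, 5))

@jax.jit
def pad_traced(img, search_window):
    # search_window is a tracer here: its value is unknown at trace time,
    # so jnp.pad cannot work out the shape of the padded output.
    return jnp.pad(img, search_window)

try:
    pad_traced(img, 5)
    error_name = None
except TypeError as err:  # exact subclass and message vary by JAX version
    error_name = type(err).__name__

print(error_name)
```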
-
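AFAIK, padding by a variable amount returns a variably sized array, which leads to an error when jitting. Computing the pad width from array shapes avoids this, because shapes are static under `jit`:

```python
import jax
from jax import numpy as jnp

img = jnp.arange(25).reshape((5, 5))
search_window = 5
result_img = jnp.zeros((img.shape[0] + 2 * search_window, img.shape[1] + 2 * search_window))

@jax.jit
def nlm(img, result_img):
    # .shape entries are plain Python ints at trace time, so these widths are static
    pad_rows = (result_img.shape[0] - img.shape[0]) // 2
    pad_cols = (result_img.shape[1] - img.shape[1]) // 2
    img_pad = jnp.pad(img, ((pad_rows, pad_rows), (pad_cols, pad_cols)))
    return img_pad

nlm(img, result_img)
```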
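Why the shape-based version works can be checked directly: inside a jitted function the arrays are tracers, but their `.shape` tuples hold ordinary Python ints, so any pad width computed from them is a trace-time constant. A small sketch (the function `pad_to_match` and the list `trace_time_widths` are hypothetical names for illustration):

```python
import jax
from jax import numpy as jnp

trace_time_widths = []  # appended to while tracing, not on every call

@jax.jit
def pad_to_match(img, target):
    # img and target are tracers, but .shape is a tuple of Python ints,
    # so pad_rows and pad_cols are ordinary integers known at trace time.
    pad_rows = (target.shape[0] - img.shape[0]) // 2
    pad_cols = (target.shape[1] - img.shape[1]) // 2
    trace_time_widths.append((pad_rows, pad_cols))
    return jnp.pad(img, ((pad_rows, pad_rows), (pad_cols, pad_cols)))

img = jnp.arange(25).reshape((5, 5))
out = pad_to_match(img, jnp.zeros((15, 15)))
print(out.shape)          # (15, 15)
print(trace_time_widths)  # [(5, 5)]
```

The output shape follows `n + 2 * w` per axis, so a 5×5 image padded by 5 on each side becomes 15×15, matching `result_img`.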
-
This is a misleading error (I'll plan to send a fix to improve it today), but as @BugQualia mentioned, the issue is that you're padding with a dynamic value. One way to fix this is to mark the pad width as static:

```python
from functools import partial
import jax
from jax import numpy as jnp

@partial(jax.jit, static_argnums=1)  # search_window is a compile-time constant
def nlm(img, search_window):
    img_pad = jnp.pad(img, search_window)  # same as jnp.pad(img, (search_window, search_window))
    return img_pad
```

See JIT mechanics: tracing and static variables for some discussion of this.
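One consequence of `static_argnums` worth knowing: JAX compiles a separate program for each distinct static value, and reuses the cached one when the same value comes back. The sketch below counts traces with a hypothetical global `trace_count` (the Python body only runs when JAX traces). This is cheap for a handful of window sizes but costly if the width changes on every call. `jax.jit(..., static_argnames="search_window")` is an equivalent way to mark the argument by name.

```python
from functools import partial

import jax
from jax import numpy as jnp

trace_count = 0  # incremented only when JAX (re)traces nlm

@partial(jax.jit, static_argnums=1)
def nlm(img, search_window):
    global trace_count
    trace_count += 1
    # search_window is a plain Python int here, so the output shape is known.
    return jnp.pad(img, search_window)

img = jnp.arange(25).reshape((5, 5))
a = nlm(img, 5)  # traces and compiles for search_window=5
b = nlm(img, 5)  # same static value: cached, no retrace
c = nlm(img, 2)  # new static value: traces and compiles again
print(a.shape, c.shape, trace_count)  # (15, 15) (9, 9) 2
```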