I am a PhD student at TUM, and we are using MuJoCo in various projects, from robot manipulation to drones.
My setup
mujoco 3.2.6
mujoco-mjx 3.2.6
x64 Ubuntu 22.04
What's happening? What did you expect?
Creating an MJX model and JIT-compiling `mjx.step` with JAX results in two compilations instead of the expected one. This noticeably slows down simulation start-up for bigger projects, since JIT-compiling the function should incur the compilation cost only once. We can reproduce the issue with a stripped-down version of the tutorial notebook.
The recompilation happens because `mjx_data.time` is weakly typed and becomes strongly typed after the first call to `mjx.step`. JAX's jit cache distinguishes weak from strong types, so the changed type makes JAX recompile the function on the second call. From step 2 onwards `mjx_data.time` stays strongly typed, so later calls do not compile again.
We can confirm that JAX compiles twice by timing the first three steps and inspecting the cache size of the jitted function.
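The same mechanism can be reproduced in plain JAX, without MuJoCo. In this minimal sketch, `advance` is a hypothetical stand-in that only mimics how a step function might advance a time value with a strongly typed increment; it is not MJX code. Like the reproduction script, it relies on the private `_cache_size()` method of jitted functions, which may change between JAX versions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def advance(t):
    # Hypothetical stand-in for a time update: adding a strongly typed
    # float32 increment makes the result strongly typed.
    return t + jnp.float32(0.01)


t0 = jnp.asarray(0.0)  # built from a Python float, so it is weakly typed
t1 = advance(t0)       # compile #1: input is a weakly typed float32 scalar
t2 = advance(t1)       # compile #2: input is now a strongly typed float32 scalar
t3 = advance(t2)       # same abstract type as the previous call, so a cache hit

print(t0.weak_type, t1.weak_type)  # True False
print(advance._cache_size())       # 2
```

The weak-type flag is part of the abstract value JAX uses as the jit cache key, which is why the second call misses the cache even though the dtype (float32) and shape are unchanged.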
Steps for reproduction
1. Load the minimal model, given as the XML string `xml`.
2. Run the code below with the line that promotes `time` to a strong type left commented out.
3. The printouts confirm that JAX compiles at both step 1 and step 2.
4. Run the code again with that line enabled.
5. The printouts show that JAX compiles only once, at step 1.
```python
import time

import jax
import mujoco
from mujoco import mjx

# `xml` must hold the minimal model's XML string.

# Make model and data
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

# mjx_data.time is currently a weak type. This causes double compilation because
# jax promotes it to a strong type on the first use, and then has to recompile
# the jit function with the changed type.
batch = jax.vmap(lambda _: mjx_data)(jax.numpy.arange(4096))
jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))

# Enabling this line avoids double compilation due to mjx_data.time being a weak type.
# batch = batch.replace(time=jax.numpy.float32(batch.time))

timings = []
for _ in range(3):
    print(f"Number of compiled functions for jit_step: {jit_step._cache_size()}")
    tstart = time.perf_counter()
    batch = jit_step(mjx_model, batch)
    jax.block_until_ready(batch)
    tend = time.perf_counter()
    timings.append(tend - tstart)
print(f"Step timings: 1: {timings[0]:.0e}s | 2: {timings[1]:.0e}s | 3: {timings[2]:.0e}s")

# Output without promoting time to a strong type:
# Number of compiled functions for jit_step: 0
# Number of compiled functions for jit_step: 1
# Number of compiled functions for jit_step: 2
# Step timings: 1: 8e-01s | 2: 6e-01s | 3: 2e-03s

# Output with promoting time to a strong type:
# Number of compiled functions for jit_step: 0
# Number of compiled functions for jit_step: 1
# Number of compiled functions for jit_step: 1
# Step timings: 1: 8e-01s | 2: 2e-03s | 3: 2e-03s
```
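The workaround could also be applied generically, stripping weak types from every leaf of the state pytree before the first jitted call. This is a sketch under assumptions: the issue only identifies `time` as weakly typed, and `strip_weak_types`, `State` and `step` are hypothetical names, with `State` standing in for `mjx.Data`. It uses `jax.lax.convert_element_type`, which returns a strongly typed array, in place of the issue's `jax.numpy.float32(batch.time)`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


def strip_weak_types(tree):
    """Return a copy of `tree` whose array leaves are all strongly typed.

    lax.convert_element_type always produces a strongly typed result, so
    converting each leaf to its own dtype only drops the weak-type flag.
    """
    return jax.tree_util.tree_map(
        lambda x: jax.lax.convert_element_type(x, x.dtype), tree
    )


class State(NamedTuple):
    """Hypothetical stand-in for mjx.Data with a `time` field."""
    time: jax.Array
    qpos: jax.Array


@jax.jit
def step(state):
    # Mimics a step that advances time with a strongly typed increment.
    return state._replace(time=state.time + jnp.float32(0.01), qpos=state.qpos + 1.0)


state = State(time=jnp.asarray(0.0), qpos=jnp.zeros(3))  # time starts weakly typed
state = strip_weak_types(state)  # promote before the first jitted call
stripped = state
for _ in range(3):
    state = step(state)

print(state.time.weak_type)  # False
print(step._cache_size())    # 1
```

In the reproduction script, the analogous call would be `batch = strip_weak_types(batch)` in place of the commented-out `batch.replace(...)` line, assuming every leaf of `mjx.Data` is a JAX array (it is already passed through `jax.vmap`, so it is a pytree).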