MJX step compiles twice due to mjx_data.time being a weak type #2306

Open
2 tasks done
amacati opened this issue Dec 22, 2024 · 1 comment
Labels
bug Something isn't working


amacati commented Dec 22, 2024

Intro

Hi!

I am a PhD student at TUM, and we are using MuJoCo in various projects ranging from robot manipulation to drones.

My setup

Setup:

  • mujoco 3.2.6
  • mujoco-mjx 3.2.6
  • x64 Ubuntu 22.04

What's happening? What did you expect?

Creating an MJX model and jit-compiling mjx.step with jax results in two compilations instead of the expected one. This noticeably slows down simulation start-up for bigger projects; jitting the function should incur the compilation cost only once. We can reproduce the issue with a stripped-down version of the tutorial notebook.

The reason for the recompilation is that mjx_data.time is weakly typed, while the data returned by the first call to mjx.step has a strongly typed time. Since jax.jit caches compiled functions by the types of their inputs (including whether an input is weakly typed), the changed type makes jax recompile the function on the second iteration. From step 2 onwards mjx_data.time stays strongly typed, so subsequent calls do not compile again.
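A minimal sketch of the type change, using plain JAX instead of MJX. Assumption: `toy_step` is a hypothetical stand-in for `mjx.step` that only advances a scalar time by a fixed, arbitrarily chosen timestep. An array built from a Python float is weakly typed, and adding a strongly typed float32 to it yields a strongly typed result with the same dtype:

```python
import jax
import jax.numpy as jnp


def toy_step(t):
    # Hypothetical stand-in for mjx.step: advance a scalar time by a fixed timestep.
    # jnp.float32(...) is strongly typed, so the sum is strongly typed as well.
    return t + jnp.float32(0.002)


t0 = jnp.asarray(0.0)       # created from a Python float -> weakly typed float32
t1 = jax.jit(toy_step)(t0)  # output of the first step -> strongly typed float32

print(repr(t0))  # repr includes "weak_type=True"
print(repr(t1))  # no weak_type flag: same dtype, but a different input type for jax.jit
```

Both arrays are float32; only the weak-type flag differs, which is enough for `jax.jit` to treat them as different signatures.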

We can confirm that jax compiles twice by timing the first steps and by inspecting the cache size of the jitted function.
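Since `_cache_size()` is a private method of jitted functions, a sketch that avoids it can count traces instead: a Python side effect inside a jitted function runs only while JAX traces it, so a counter increments once per new input signature. `count_traces` and `toy_step` are hypothetical names; `toy_step` again stands in for `mjx.step`.

```python
import jax
import jax.numpy as jnp


def count_traces(t, n_steps=3):
    """Run a jitted toy step n_steps times and return how often JAX traced it."""
    traces = {"n": 0}

    def toy_step(t):
        traces["n"] += 1  # Python side effect: executes only during tracing
        return t + jnp.float32(0.002)

    step = jax.jit(toy_step)  # fresh jit per call, so cache entries do not leak between runs
    for _ in range(n_steps):
        t = step(t)
    return traces["n"]


weak_count = count_traces(jnp.asarray(0.0))    # weakly typed start, like mjx_data.time
strong_count = count_traces(jnp.float32(0.0))  # promoted up front, mirroring the fix in the issue
print(weak_count, strong_count)  # 2 1
```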

Steps for reproduction

  1. Load the model below.
  2. Run the code below without enabling the line promoting time to a strong type.
  3. The printouts confirm that jax compiles twice, at steps 1 and 2.
  4. Run the code below again with the line promoting time to a strong type.
  5. The printouts show that jax only jits once at step 1.

Minimal model for reproduction

minimal XML
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <body name="box_and_sphere" euler="0 0 -30">
      <joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
      <geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
      <geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
    </body>
  </worldbody>
</mujoco>

Code required for reproduction

import time

import jax
import mujoco
from mujoco import mjx

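# Minimal model from above, as an MJCF string
xml = """
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <body name="box_and_sphere" euler="0 0 -30">
      <joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
      <geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
      <geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
    </body>
  </worldbody>
</mujoco>
"""

# Make model and data
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
# mjx_data.time is currently weakly typed. This causes a double compilation: mjx.step returns
# a strongly typed time, so jax has to recompile the jitted function for the changed type.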

batch = jax.vmap(lambda _: mjx_data)(jax.numpy.arange(4096))
jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))
# Enabling this line avoids double compilation due to mjx_data.time being a weak type.
# batch = batch.replace(time=jax.numpy.float32(batch.time))

timings = []
for _ in range(3):
    print(f"Number of compiled functions for jit_step: {jit_step._cache_size()}")
    tstart = time.perf_counter()
    batch = jit_step(mjx_model, batch)
    jax.block_until_ready(batch)
    tend = time.perf_counter()
    timings.append(tend - tstart)

print(f"Step timings: 1: {timings[0]:.0e}s | 2: {timings[1]:.0e}s | 3: {timings[2]:.0e}s")
# Output without promoting time to a strong type:
# Number of compiled functions for jit_step: 0
# Number of compiled functions for jit_step: 1
# Number of compiled functions for jit_step: 2
# Step timings: 1: 8e-01s | 2: 6e-01s | 3: 2e-03s

# Output with promoting time to a strong type:
# Number of compiled functions for jit_step: 0
# Number of compiled functions for jit_step: 1
# Number of compiled functions for jit_step: 1
# Step timings: 1: 8e-01s | 2: 2e-03s | 3: 2e-03s
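The commented-out fix above promotes only `time`. If other fields of a state pytree might also be weakly typed, a more general helper can cast every leaf to its own dtype with `jax.lax.convert_element_type`, which always returns a strongly typed array. This is a hypothetical sketch, not tested against MJX: `strip_weak_types` and `toy_step` are illustrative names, a plain dict stands in for `mjx.Data`, and every leaf is assumed to be a JAX array.

```python
import jax
import jax.numpy as jnp


def strip_weak_types(tree):
    """Hypothetical helper: return a copy of `tree` whose array leaves are strongly typed.

    Casting each leaf to its own dtype keeps values and dtypes unchanged, but
    jax.lax.convert_element_type always produces a strongly typed result.
    Assumes every leaf is a JAX array.
    """
    return jax.tree_util.tree_map(lambda x: jax.lax.convert_element_type(x, x.dtype), tree)


traces = {"n": 0}


@jax.jit
def toy_step(state):
    traces["n"] += 1  # runs only while JAX traces toy_step
    return {"time": state["time"] + jnp.float32(0.002), "qpos": state["qpos"] * 2.0}


# Plain dict standing in for mjx.Data; "time" starts out weakly typed.
state = strip_weak_types({"time": jnp.asarray(0.0), "qpos": jnp.zeros(3)})
for _ in range(3):
    state = toy_step(state)
print(traces["n"])  # 1
```

Casting every leaf is cheap compared to a recompilation, and it does not require knowing in advance which fields are weakly typed.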


amacati added the bug label (Something isn't working) on Dec 22, 2024
junqingqiao commented:

I used JAX ahead-of-time (AOT) compilation to work around this problem for now: step_fn = jax.jit(step_fn).lower(mx, dx).compile()
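A minimal sketch of this AOT workaround, again with a hypothetical `toy_step` in place of `mjx.step`: `jax.jit(...).lower(...)` traces the function once for the given example input, and `.compile()` returns an executable that is never retraced. That the executable accepts the strongly typed outputs of later steps matches what the comment reports; it relies on JAX's argument check for compiled executables ignoring the weak-type flag, which is an assumption here rather than something verified against every JAX version.

```python
import jax
import jax.numpy as jnp

traces = 0


def toy_step(t):
    global traces
    traces += 1  # Python side effect: runs only while JAX traces toy_step
    return t + jnp.float32(0.002)


t = jnp.asarray(0.0)  # weakly typed, like mjx_data.time from mjx.put_data

# Trace and compile exactly once, ahead of time, for the initial input.
compiled_step = jax.jit(toy_step).lower(t).compile()

for _ in range(3):
    t = compiled_step(t)  # later inputs are strongly typed; the executable is reused as-is

print(traces)  # 1
```

One trade-off: unlike `jax.jit`, a compiled executable does not recompile when shapes or dtypes change; it raises an error instead, so it fits best when the state layout is fixed.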
