Mujoco MJX stepping is not reproducible for some models #2252
Labels: MJX (Using JAX to run on GPU)
Comments
KamatMayur changed the title from "mujoco mjx stepping is not reproducible for some models" to "Mujoco MJX stepping is not reproducible for some models" on Nov 26, 2024.
Additional follow-up with some more observations. I used jax.lax.scan() to simulate and, surprisingly, the results were now consistent. However, the scan version is much slower than the Python for loop, which again I am unable to understand.

```python
import time
import mujoco
import mujoco.mjx as mjx
import jax
import os
from jax import lax

xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

path = "temp.xml"
jit_step = jax.jit(mjx.step)
model = mujoco.MjModel.from_xml_path(path)
data = mujoco.MjData(model)
model_mjx = mjx.put_model(model)
data_mjx = mjx.put_data(model, data)
_ = jit_step(model_mjx, data_mjx)

def timeit_decorator(func):
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        print(f"Function {func.__name__} executed in {elapsed_time:.6f} seconds.")
        return result
    return wrapper

@timeit_decorator
def simulate_1(path):
    model = mujoco.MjModel.from_xml_path(path)
    data = mujoco.MjData(model)
    for _ in range(2000):
        mujoco.mj_step(model, data)
    return data

@timeit_decorator
def simulate_2(path):
    model = mujoco.MjModel.from_xml_path(path)
    data = mujoco.MjData(model)
    model_mjx = mjx.put_model(model)
    data_mjx = mjx.put_data(model, data)
    for _ in range(2000):
        data_mjx = jit_step(model_mjx, data_mjx)
    return data_mjx

@timeit_decorator
def simulate_3(path):
    # Load the Mujoco model and data
    model = mujoco.MjModel.from_xml_path(path)
    data = mujoco.MjData(model)
    model_mjx = mjx.put_model(model)
    data_mjx = mjx.put_data(model, data)

    # Define a single step function for lax.scan
    def step_fn(data_mjx, _):
        data_mjx = jit_step(model_mjx, data_mjx)
        return data_mjx, None

    # Use lax.scan to perform 2000 steps
    final_data_mjx, _ = lax.scan(step_fn, data_mjx, None, length=2000)
    return final_data_mjx

print(simulate_1(path).qpos[:6])
print(simulate_1(path).qpos[:6])
print(simulate_2(path).qpos[:6])
print(simulate_2(path).qpos[:6])
print(simulate_3(path).qpos[:6])
print(simulate_3(path).qpos[:6])
```

The outputs I get are as follows:
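On the speed difference between the scan version and the Python loop: one likely contributor is that simulate_3 calls `lax.scan` outside of `jax.jit`, and since `step_fn` is a brand-new closure (over a new `model_mjx`) on every call, JAX probably traces and compiles the whole scanned loop, including `mjx.step`, each time `simulate_3` runs, whereas `simulate_2` reuses the already-warmed `jit_step`. A second factor is that JAX dispatches work asynchronously, so a timer around the Python loop can stop before the device has actually finished. The sketch below is a hedged illustration, not MJX's recommended pattern: `toy_step` is a hypothetical stand-in for `mjx.step` so it runs without MuJoCo, the model is passed as an explicit argument rather than captured in a closure, the scan is wrapped in `jax.jit`, both paths are warmed up, and the timer waits on the result.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def toy_step(params, state):
    # Hypothetical stand-in for mjx.step(model, data): any pure function of
    # (params, state) -> new state behaves the same way with respect to jit/scan.
    return state + params["dt"] * jnp.sin(state)


jit_step = jax.jit(toy_step)


@partial(jax.jit, static_argnames="n_steps")
def rollout_scan(params, state, n_steps):
    # The whole loop is traced once and compiled into a single XLA program;
    # later calls with the same shapes and n_steps reuse that compilation.
    def body(s, _):
        return toy_step(params, s), None

    final, _ = lax.scan(body, state, None, length=n_steps)
    return final


def rollout_loop(params, state, n_steps):
    # One dispatch of the already-compiled jit_step per iteration.
    for _ in range(n_steps):
        state = jit_step(params, state)
    return state


def timed(fn, *args, **kwargs):
    # JAX dispatches work asynchronously, so wait for the result before
    # stopping the clock; otherwise only dispatch time is measured.
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args, **kwargs))
    return out, time.perf_counter() - start


params = {"dt": jnp.float32(0.01)}
state0 = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)

# Warm-up calls so that compilation is excluded from the timings below.
rollout_scan(params, state0, n_steps=2000)
rollout_loop(params, state0, 1)

scan_out, scan_time = timed(rollout_scan, params, state0, n_steps=2000)
loop_out, loop_time = timed(rollout_loop, params, state0, 2000)
print(f"scan: {scan_time:.6f} s, loop: {loop_time:.6f} s")
print("max |scan - loop|:", float(jnp.max(jnp.abs(scan_out - loop_out))))
```

With MJX, the analogous change would be to jit a function that takes `model_mjx` and `data_mjx` as arguments and runs the scan inside it, then call that function once to compile before timing.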
Intro
Hi!
I am a graduate, and I use MuJoCo for my research in RL.
My setup
MuJoCo version: 3.2.5
Python API
64-bit
Ubuntu 24.04.1 LTS
RTX 2060 Super, 8 GB @ 2010 MHz
What's happening? What did you expect?
I'm simulating a model for a number of steps, multiple times, resetting it before each run. The results are always consistent with mujoco.mj_step(), but never consistent with mjx.step(). The following code can be used to reproduce the results. I tested my own XML model, but the same happens with the humanoid.xml model in the mjx/test_data/humanoid/ directory. The output of the MJX step is different every time, even after restarting my ipynb notebook kernel.
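To make "not reproducible" precise, it helps to run the same rollout twice from the same initial state and compare the results bit for bit, reporting the largest difference: a last-bit mismatch and a completely divergent trajectory point to different causes. A minimal sketch, assuming a functional step (state in, new state out) such as the jitted `mjx.step`; the helper names `rollout` and `check_reproducible` are hypothetical, and the toy step only stands in for MJX so the snippet runs on its own.

```python
import numpy as np


def rollout(step, state, n_steps):
    """Apply a functional `step` n_steps times and return the final state."""
    for _ in range(n_steps):
        state = step(state)
    return state


def check_reproducible(step, state0, n_steps, extract=np.asarray):
    """Run the same rollout twice from `state0` and compare bit for bit.

    Returns (identical, max_abs_diff) for the arrays produced by `extract`.
    Only valid for steps that return a new state without mutating `state0`
    (true for MJX data, which is immutable; not true for mujoco.mj_step,
    which updates an MjData in place and returns None).
    """
    a = np.asarray(extract(rollout(step, state0, n_steps)))
    b = np.asarray(extract(rollout(step, state0, n_steps)))
    return bool(np.array_equal(a, b)), float(np.max(np.abs(a - b)))


# Toy deterministic step standing in for `lambda d: jit_step(model_mjx, d)`.
identical, max_diff = check_reproducible(lambda x: 0.5 * x + 1.0, np.array([0.0, 3.0]), 50)
print(identical, max_diff)  # True 0.0

# Assumed MJX usage (not run here):
#   check_reproducible(lambda d: jit_step(model_mjx, d), data_mjx, 2000,
#                      extract=lambda d: d.qpos)
```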
NOTE: This only happens with models similar to the ones defined in mjx or brax, not with the models under mujoco/model/. The mujoco step, however, is always consistent regardless of the model.
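The observation that only some models are affected does not by itself identify a cause. As general background, not a diagnosis of this issue: GPU kernels that reduce values in parallel (for example with atomics, or with kernel choices made by autotuning) are not guaranteed to sum in the same order on every run, and because floating-point addition is not associative, that can change the last bits of a result. A model with many interacting bodies and contacts, like the humanoid, can be chaotic enough that such tiny differences grow over 2000 steps into visibly different `qpos`. The plain-Python snippet below illustrates both effects, with a logistic map (a hypothetical stand-in, not MJX) playing the role of the simulator.

```python
def logistic_trajectory(x, n_steps):
    # x_{k+1} = 4 x_k (1 - x_k): a chaotic map in which a tiny perturbation
    # roughly doubles per step on average.
    traj = []
    for _ in range(n_steps):
        x = 4.0 * x * (1.0 - x)
        traj.append(x)
    return traj


# 1) Floating-point addition is not associative: summing the same three
#    numbers in a different order gives a different double.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)  # 0.6000000000000001 0.6 False

# 2) A chaotic system turns a last-bit difference into an O(1) difference.
a = logistic_trajectory(0.3, 100)
b = logistic_trajectory(0.3 + 1e-15, 100)
gaps = [abs(p - q) for p, q in zip(a, b)]
first_big = next((k for k, g in enumerate(gaps) if g > 1e-3), None)
print("first step with |a - b| > 1e-3:", first_big)
```

If nondeterministic GPU reductions do turn out to be the cause, one thing worth trying (assuming the installed XLA version still supports it) is appending --xla_gpu_deterministic_ops=true to XLA_FLAGS before JAX initializes its GPU backend, the same way the follow-up comment appends --xla_gpu_triton_gemm_any=True; this generally trades some speed for run-to-run reproducibility.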
Steps for reproduction
Run the code given below with the humanoid.xml model
Minimal model for reproduction
Use the humanoid.xml model under the mjx/test_data/humanoid/ directory.
Code required for reproduction
Confirmations