
Multimer prediction error: "ValueError: Unable to retrieve parameter 'scale' for module......." #38

Open
tclin422 opened this issue Aug 3, 2023 · 5 comments

tclin422 commented Aug 3, 2023

Hi,

I've finished the feature step and got the output features.pkl for the multimer prediction. Then I got these error messages:

(parafold) [linx@localhost ParallelFold-main]$ I0803 08:18:26.620421 139735959648064 templates.py:857] Using precomputed obsolete pdbs /data/linx/02_Database/01_AF2/pdb_mmcif/obsolete.dat.
I0803 08:18:26.864619 139735959648064 xla_bridge.py:603] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
I0803 08:18:26.865048 139735959648064 xla_bridge.py:603] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0803 08:18:30.334805 139735959648064 run_alphafold.py:445] Have 25 models: ['model_1_multimer_pred_0', 'model_1_multimer_pred_1', 'model_1_multimer_pred_2', 'model_1_multimer_pred_3', 'model_1_multimer_pred_4', 'model_2_multimer_pred_0', 'model_2_multimer_pred_1', 'model_2_multimer_pred_2', 'model_2_multimer_pred_3', 'model_2_multimer_pred_4', 'model_3_multimer_pred_0', 'model_3_multimer_pred_1', 'model_3_multimer_pred_2', 'model_3_multimer_pred_3', 'model_3_multimer_pred_4', 'model_4_multimer_pred_0', 'model_4_multimer_pred_1', 'model_4_multimer_pred_2', 'model_4_multimer_pred_3', 'model_4_multimer_pred_4', 'model_5_multimer_pred_0', 'model_5_multimer_pred_1', 'model_5_multimer_pred_2', 'model_5_multimer_pred_3', 'model_5_multimer_pred_4']
I0803 08:18:30.335292 139735959648064 run_alphafold.py:459] Using random seed 338821470640361521 for the data pipeline
I0803 08:18:30.335713 139735959648064 run_alphafold.py:189] Predicting R04373_impa
I0803 08:18:30.491963 139735959648064 run_alphafold.py:231] Running model model_1_multimer_pred_0 on R04373_impa
I0803 08:18:30.492612 139735959648064 model.py:165] Running predict with shape(feat) = {'aatype': (1787,), 'residue_index': (1787,), 'seq_length': (), 'msa': (3072, 1787), 'num_alignments': (), 'template_aatype': (4, 1787), 'template_all_atom_mask': (4, 1787, 37), 'template_all_atom_positions': (4, 1787, 37, 3), 'asym_id': (1787,), 'sym_id': (1787,), 'entity_id': (1787,), 'deletion_matrix': (3072, 1787), 'deletion_mean': (1787,), 'all_atom_mask': (1787, 37), 'all_atom_positions': (1787, 37, 3), 'assembly_num_chains': (), 'entity_mask': (1787,), 'num_templates': (), 'cluster_bias_mask': (3072,), 'bert_mask': (3072, 1787), 'seq_mask': (1787,), 'msa_mask': (3072, 1787)}
Traceback (most recent call last):
File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 491, in <module>
app.run(main)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 464, in main
predict_structure(
File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 239, in predict_structure
prediction_result = model_runner.predict(processed_feature_dict,
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/model.py", line 167, in predict
result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/api.py", line 306, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/transform.py", line 128, in apply_fn
out, state = f.apply(params, {}, *args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/transform.py", line 357, in apply_fn
out = f(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/model.py", line 77, in _forward_fn
return model(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 508, in __call__
num_recycles, _, prev, safe_key = hk.while_loop(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 898, in while_loop
val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 1150, in while_loop
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 1133, in _create_jaxpr
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 64, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 58, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 891, in pure_body_fun
val = body_fun(val)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 486, in recycle_body
ret = apply_network(prev=prev, safe_key=safe_key2)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 449, in apply_network
return impl(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 321, in __call__
repr_shape = hk.eval_shape(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape
out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/api.py", line 2783, in eval_shape
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 670, in abstract_eval_fun
_, avals_out, _ = trace_to_jaxpr_dynamic(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun
out = fun(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 322, in <lambda>
lambda: embedding_module(batch, is_training))
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 696, in __call__
template_act = template_module(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 883, in __call__
summed_template_embeddings, _ = hk.scan(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan
(carry, state), ys = jax.lax.scan(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 250, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 236, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 64, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 58, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun
carry, out = f(carry, x)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 879, in scan_fn
return carry + partial_template_embedder(*x), None
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 866, in partial_template_embedder
return template_embedder(query_embedding,
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1036, in __call__
act, safe_key = template_stack((act, safe_subkey))
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 265, in wrapped
ret = _LayerStackNoState(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 156, in __call__
carry, zs = hk.scan(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan
(carry, state), ys = jax.lax.scan(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 250, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 236, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 64, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 58, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun
carry, out = f(carry, x)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 149, in layer
out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 182, in _call_wrapped
ret = self._f(*args)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1022, in template_iteration_fn
act = template_iteration(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1084, in __call__
act = dropout_wrapper_fn(
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 76, in dropout_wrapper
residual = module(input_act, mask, is_training=is_training, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 1316, in __call__
return self._triangle_multiplication(left_act, left_mask)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 1326, in _triangle_multiplication
act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/common_modules.py", line 176, in __call__
scale = hk.get_parameter(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/base.py", line 448, in wrapped
return wrapped._current(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/base.py", line 524, in get_parameter
raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Unable to retrieve parameter 'scale' for module 'alphafold/alphafold_iteration/evoformer/template_embedding/single_template_embedding/template_embedding_iteration/triangle_multiplication_outgoing/layer_norm_input' All parameters must be created as part of init.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 491, in <module>
app.run(main)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 464, in main
predict_structure(
File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 239, in predict_structure
prediction_result = model_runner.predict(processed_feature_dict,
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/model.py", line 167, in predict
result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/transform.py", line 128, in apply_fn
out, state = f.apply(params, {}, *args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/transform.py", line 357, in apply_fn
out = f(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/model.py", line 77, in _forward_fn
return model(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 508, in __call__
num_recycles, _, prev, safe_key = hk.while_loop(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 898, in while_loop
val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 891, in pure_body_fun
val = body_fun(val)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 486, in recycle_body
ret = apply_network(prev=prev, safe_key=safe_key2)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 449, in apply_network
return impl(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 321, in __call__
repr_shape = hk.eval_shape(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape
out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun
out = fun(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 322, in <lambda>
lambda: embedding_module(batch, is_training))
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 696, in __call__
template_act = template_module(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 883, in __call__
summed_template_embeddings, _ = hk.scan(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan
(carry, state), ys = jax.lax.scan(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun
carry, out = f(carry, x)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 879, in scan_fn
return carry + partial_template_embedder(*x), None
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 866, in partial_template_embedder
return template_embedder(query_embedding,
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1036, in __call__
act, safe_key = template_stack((act, safe_subkey))
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 265, in wrapped
ret = _LayerStackNoState(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 156, in __call__
carry, zs = hk.scan(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan
(carry, state), ys = jax.lax.scan(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun
carry, out = f(carry, x)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 149, in layer
out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 182, in _call_wrapped
ret = self._f(*args)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1022, in template_iteration_fn
act = template_iteration(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1084, in __call__
act = dropout_wrapper_fn(
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 76, in dropout_wrapper
residual = module(input_act, mask, is_training=is_training, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 1316, in __call__
return self._triangle_multiplication(left_act, left_mask)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 1326, in _triangle_multiplication
act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
out = f(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/common_modules.py", line 176, in __call__
scale = hk.get_parameter(
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/base.py", line 448, in wrapped
return wrapped._current(*args, **kwargs)
File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/base.py", line 524, in get_parameter
raise ValueError(
ValueError: Unable to retrieve parameter 'scale' for module 'alphafold/alphafold_iteration/evoformer/template_embedding/single_template_embedding/template_embedding_iteration/triangle_multiplication_outgoing/layer_norm_input' All parameters must be created as part of init.
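
The error means the weights loaded for this model contain no 'scale' parameter under the listed module path, i.e. the weights file and the model configuration do not agree. A minimal sketch for checking that directly in a weights file, assuming the parameters are stored in a .npz file under keys of the form "<module scope>//<parameter name>"; check_param is a hypothetical helper, not part of ParallelFold or AlphaFold:

```python
import sys

import numpy as np

# Module scope and parameter name copied from the ValueError above.
MISSING_SCOPE = (
    "alphafold/alphafold_iteration/evoformer/template_embedding/"
    "single_template_embedding/template_embedding_iteration/"
    "triangle_multiplication_outgoing/layer_norm_input"
)
MISSING_NAME = "scale"


def check_param(npz_path, scope=MISSING_SCOPE, name=MISSING_NAME):
    """Return (found, sibling_keys) for one parameter in a .npz weights file.

    Assumption: keys are stored as "<module scope>//<parameter name>".
    """
    with np.load(npz_path) as params:
        keys = list(params.files)
    found = f"{scope}//{name}" in keys
    # Keys under the parent module (here .../triangle_multiplication_outgoing)
    # show which sub-modules the file provides instead.
    parent = scope.rsplit("/", 1)[0]
    siblings = sorted(k for k in keys if k.startswith(parent + "/"))
    return found, siblings


if __name__ == "__main__" and len(sys.argv) > 1:
    found, siblings = check_param(sys.argv[1])
    print("parameter present:", found)
    for key in siblings:
        print("  ", key)
```

If the parameter is absent but other keys exist under triangle_multiplication_outgoing, the file was most likely produced for a different model variant than the one named on the command line.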

Zuricho (Owner) commented Aug 7, 2023

How did you run the feature step? Could you provide the run_alphafold.sh command you used for the feature step?

tclin422 (Author) commented:

[screenshot]

gilspeyer commented:

I am also getting this error. tclin422, did you find out what caused it?

von-elfen commented:

I also got this error. Have you solved it? This error is so weird; I tried many solutions but none of them worked :(

luwei0917 commented:

I got the same error and solved it.
It happened because I was using '-m model_1_multimer,model_2_multimer,model_3_multimer,model_4_multimer,model_5_multimer',
but I should have used '-m model_1_multimer_v3,model_2_multimer_v3,model_3_multimer_v3,model_4_multimer_v3,model_5_multimer_v3'.
I had also created a symlink from the model params params_model_1_multimer_v3.npz to params_model_1_multimer.npz (which I shouldn't have done).
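
Following that explanation, one way to catch this setup mistake before a long run is to check, for each name passed to -m, which weights file would actually be read. A minimal sketch, assuming the weights sit at <data_dir>/params/params_<model_name>.npz (AlphaFold's usual parameter layout); check_model_params is a hypothetical helper:

```python
import os


def check_model_params(data_dir, model_names):
    """Report which weights file each model name would load.

    Assumption: weights live at <data_dir>/params/params_<model_name>.npz.
    """
    report = {}
    for name in model_names:
        path = os.path.join(data_dir, "params", f"params_{name}.npz")
        if not os.path.exists(path):
            # Also covers symlinks whose target no longer exists.
            report[name] = "missing"
        elif os.path.islink(path):
            # A symlink makes one model name load another model's weights.
            target = os.path.basename(os.path.realpath(path))
            report[name] = f"symlink -> {target}"
        else:
            report[name] = "ok"
    return report
```

For example, check_model_params("/path/to/data", "model_1_multimer_v3,model_2_multimer_v3".split(",")) with a placeholder data directory. A name reported as a symlink to a file with a different version suffix, such as params_model_1_multimer.npz pointing to params_model_1_multimer_v3.npz, is the situation described above.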
