CustomCall failed: jaxlib/gpu/prng_kernels #339

Open
ERIK54600 opened this issue Sep 4, 2024 · 1 comment

@ERIK54600

I get the following error while running this notebook in Google Colab (I am using Colab Pro):

```
XlaRuntimeError Traceback (most recent call last)
in <cell line: 9>()
9 for i in trange(max(n_predictions // jax.device_count(), 1)):
10 # get a new key
---> 11 key, subkey = jax.random.split(key)
12 # generate images
13 encoded_images = p_generate(

10 frames
[... skipping hidden 2 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, has_host_callbacks, *args)
893 runtime_token = None
894 else:
--> 895 out_flat = compiled.execute(in_flat)
896 check_special(name, out_flat)
897 out_bufs = unflatten(out_flat, output_buffer_counts)

XlaRuntimeError: INTERNAL: CustomCall failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: cudaGetErrorString symbol not found
```
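To narrow this down, a minimal cell that runs only the failing operation may help. This is a sketch that is not part of the original notebook: it prints the installed `jax`/`jaxlib` versions and the devices JAX sees, then calls `jax.random.split` on its own. A mismatch between the `jaxlib` CUDA build and the CUDA runtime in the Colab image is one plausible, unconfirmed cause of a `cudaGetErrorString symbol not found` error; if this cell fails the same way, the problem is in the JAX installation rather than in the model code.

```python
from importlib.metadata import version

import jax

# Report installed versions and the devices JAX is using. A jaxlib build
# that does not match the machine's CUDA runtime is only an assumed cause.
print("jax", jax.__version__, "| jaxlib", version("jaxlib"))
print("default backend:", jax.default_backend())
print("devices:", jax.devices())

# The same call that fails in the traceback: split a PRNG key on the
# default backend (the GPU, if one is visible to JAX).
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
print("key:", key, "subkey:", subkey)
```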

@ERIK54600

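This happens when running the following cell (`jax`, `key`, `prompts`, `n_predictions`, `p_generate`, `p_decode`, and the model parameters are defined in earlier cells of the notebook):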

```python
from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

print(f"Prompts: {prompts}\n")
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
        display(img)
        print()
```
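The key handling in this cell can be exercised without loading the model, which helps separate a PRNG failure from a problem in `p_generate`/`p_decode`. This is a sketch under assumptions: `split_for_devices` is a hypothetical helper standing in for flax's `shard_prng_key`, which is assumed here to split the key into `jax.local_device_count()` subkeys, and `n_iterations` is an illustrative stand-in for the notebook's loop count.

```python
import jax


def split_for_devices(key, n_devices):
    # Hypothetical stand-in for shard_prng_key: one independent subkey per
    # local device, so each pmapped replica samples different images.
    return jax.random.split(key, n_devices)


key = jax.random.PRNGKey(0)
n_iterations = 3  # illustrative; the notebook derives this from n_predictions
per_iteration_keys = []
for _ in range(n_iterations):
    # Advance the carried key and take a fresh subkey, as the cell above does.
    key, subkey = jax.random.split(key)
    device_keys = split_for_devices(subkey, jax.local_device_count())
    per_iteration_keys.append(device_keys)

print(per_iteration_keys[0].shape)
```

Because the carried `key` is replaced on every iteration, each batch gets new randomness; it also means the first `jax.random.split` on the GPU runs before any model code, which matches the traceback stopping at that line.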
