From 638254f11303f4d596e16b99b4380da41df3c0e1 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Tue, 22 Aug 2023 03:51:12 +0000 Subject: [PATCH] style: lint --- app/gradio/app.py | 1 - src/dalle_mini/model/modeling.py | 5 +- tools/inference/inference_pipeline.ipynb | 1116 +++++++++++----------- tools/train/train.py | 2 - 4 files changed, 560 insertions(+), 564 deletions(-) diff --git a/app/gradio/app.py b/app/gradio/app.py index 649bdd854..39d17402a 100644 --- a/app/gradio/app.py +++ b/app/gradio/app.py @@ -22,7 +22,6 @@ def infer(prompt): with gr.Group(): with gr.Box(): with gr.Row().style(mobile_collapse=False, equal_height=True): - text = gr.Textbox( label="Enter your prompt", show_label=False, max_lines=1 ).style( diff --git a/src/dalle_mini/model/modeling.py b/src/dalle_mini/model/modeling.py index 25af7a1d5..53dd7ce96 100644 --- a/src/dalle_mini/model/modeling.py +++ b/src/dalle_mini/model/modeling.py @@ -77,6 +77,7 @@ def _smelu(x: Any) -> Any: ACT2FN.update({"smelu": smelu()}) + # deepnet initialization def deepnet_init(init_std, gain=1): init = jax.nn.initializers.normal(init_std) @@ -498,7 +499,6 @@ class GLU(nn.Module): @nn.compact def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - if self.config.use_deepnet_scaling: gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"]( self.config @@ -567,7 +567,6 @@ class FFN(nn.Module): @nn.compact def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: - if self.config.use_deepnet_scaling: gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"]( self.config @@ -634,7 +633,6 @@ def __call__( output_attentions: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: - if self.config.use_scan: hidden_states = hidden_states[0] @@ -742,7 +740,6 @@ def __call__( output_attentions: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: - if self.config.use_scan: hidden_states = hidden_states[0] diff --git a/tools/inference/inference_pipeline.ipynb b/tools/inference/inference_pipeline.ipynb index 298a550ff..43455fe43 100644 --- a/tools/inference/inference_pipeline.ipynb +++ b/tools/inference/inference_pipeline.ipynb @@ -1,559 +1,561 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "118UKH5bWCGa" - }, - "source": [ - "# DALL·E mini - Inference pipeline\n", - "\n", - "*Generate images from a text prompt*\n", - "\n", - "\n", - "\n", - "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n", - "\n", - "Just want to play? Use directly [the app](https://www.craiyon.com/).\n", - "\n", - "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dS8LbaonYm3a" - }, - "source": [ - "## 🛠️ Installation and set-up" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "uzjAM2GBYpZX" - }, - "outputs": [], - "source": [ - "# Required only for colab environments + GPU\n", - "!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", - "\n", - "# Install required libraries\n", - "!pip install -q dalle-mini orbax==0.0.23\n", - "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ozHzTkyv8cqU" - }, - "source": [ - "We load required models:\n", - "* DALL·E mini for text to encoded images\n", - "* VQGAN for decoding images\n", - "* CLIP for scoring predictions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "K6CxW2o42f-w" - }, - "outputs": [], - "source": [ - "# Model references\n", - "\n", - "# dalle-mega\n", - "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n", - "DALLE_COMMIT_ID = None\n", - "\n", - "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n", - "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n", - "\n", - "# VQGAN model\n", - "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n", - "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Yv-aR3t4Oe5v" - }, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "# check how many devices are available\n", - "jax.local_device_count()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "92zYmvsQ38vL" - }, - "outputs": [], - "source": [ - "# Load models & tokenizer\n", - "from dalle_mini import DalleBart, DalleBartProcessor\n", - "from vqgan_jax.modeling_flax_vqgan import VQModel\n", - "from transformers import CLIPProcessor, FlaxCLIPModel\n", - "\n", - "# Load dalle-mini\n", - "model, params = DalleBart.from_pretrained(\n", - " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n", - ")\n", - "\n", - "# Load VQGAN\n", - "vqgan, vqgan_params = VQModel.from_pretrained(\n", - " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "o_vH2X1tDtzA" - }, - "source": [ - "Model parameters are replicated on each device for faster inference." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wtvLoM48EeVw" - }, - "outputs": [], - "source": [ - "from flax.jax_utils import replicate\n", - "\n", - "params = replicate(params)\n", - "vqgan_params = replicate(vqgan_params)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0A9AHQIgZ_qw" - }, - "source": [ - "Model functions are compiled and parallelized to take advantage of multiple devices." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sOtoOmYsSYPz" - }, - "outputs": [], - "source": [ - "from functools import partial\n", - "\n", - "# model inference\n", - "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n", - "def p_generate(\n", - " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n", - "):\n", - " return model.generate(\n", - " **tokenized_prompt,\n", - " prng_key=key,\n", - " params=params,\n", - " top_k=top_k,\n", - " top_p=top_p,\n", - " temperature=temperature,\n", - " condition_scale=condition_scale,\n", - " )\n", - "\n", - "\n", - "# decode image\n", - "@partial(jax.pmap, axis_name=\"batch\")\n", - "def p_decode(indices, params):\n", - " return vqgan.decode_code(indices, params=params)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HmVN6IBwapBA" - }, - "source": [ - "Keys are passed to the model on each device to generate unique inference per device." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4CTXmlUkThhX" - }, - "outputs": [], - "source": [ - "import random\n", - "\n", - "# create a random key\n", - "seed = random.randint(0, 2**32 - 1)\n", - "key = jax.random.PRNGKey(seed)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BrnVyCo81pij" - }, - "source": [ - "## 🖍 Text Prompt" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rsmj0Aj5OQox" - }, - "source": [ - "Our model requires processing prompts." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "YjjhUychOVxm" - }, - "outputs": [], - "source": [ - "from dalle_mini import DalleBartProcessor\n", - "\n", - "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BQ7fymSPyvF_" - }, - "source": [ - "Let's define some text prompts." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "x_0vI9ge1oKr" - }, - "outputs": [], - "source": [ - "prompts = [\n", - " \"sunset over a lake in the mountains\",\n", - " \"the Eiffel tower landing on the moon\",\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XlZUG3SCLnGE" - }, - "source": [ - "Note: we could use the same prompt multiple times for faster inference." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VKjEZGjtO49k" - }, - "outputs": [], - "source": [ - "tokenized_prompts = processor(prompts)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-CEJBnuJOe5z" - }, - "source": [ - "Finally we replicate the prompts onto each device." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "lQePgju5Oe5z" - }, - "outputs": [], - "source": [ - "tokenized_prompt = replicate(tokenized_prompts)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "phQ9bhjRkgAZ" - }, - "source": [ - "## 🎨 Generate images\n", - "\n", - "We generate images using dalle-mini model and decode them with the VQGAN." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "d0wVkXpKqnHA" - }, - "outputs": [], - "source": [ - "# number of predictions per prompt\n", - "n_predictions = 8\n", - "\n", - "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n", - "gen_top_k = None\n", - "gen_top_p = None\n", - "temperature = None\n", - "cond_scale = 10.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "SDjEx9JxR3v8" - }, - "outputs": [], - "source": [ - "from flax.training.common_utils import shard_prng_key\n", - "import numpy as np\n", - "from PIL import Image\n", - "from tqdm.notebook import trange\n", - "\n", - "print(f\"Prompts: {prompts}\\n\")\n", - "# generate images\n", - "images = []\n", - "for i in trange(max(n_predictions // jax.device_count(), 1)):\n", - " # get a new key\n", - " key, subkey = jax.random.split(key)\n", - " # generate images\n", - " encoded_images = p_generate(\n", - " tokenized_prompt,\n", - " shard_prng_key(subkey),\n", - " params,\n", - " gen_top_k,\n", - " gen_top_p,\n", - " temperature,\n", - " cond_scale,\n", - " )\n", - " # remove BOS\n", - " encoded_images = encoded_images.sequences[..., 1:]\n", - " # decode images\n", - " decoded_images = p_decode(encoded_images, vqgan_params)\n", - " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n", - " for decoded_img in decoded_images:\n", - " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n", - " images.append(img)\n", - " display(img)\n", - " print()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tw02wG9zGmyB" - }, - "source": [ - "## 🏅 Optional: Rank images by CLIP score\n", - "\n", - "We can rank images according to CLIP.\n", - "\n", - "**Note: your session may crash if you don't have a subscription to Colab Pro.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "RGjlIW_f6GA0" - }, - "outputs": [], - "source": [ - "# CLIP model\n", - "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n", - "CLIP_COMMIT_ID = None\n", - "\n", - "# Load CLIP\n", - "clip, clip_params = FlaxCLIPModel.from_pretrained(\n", - " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n", - ")\n", - "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n", - "clip_params = replicate(clip_params)\n", - "\n", - "# score images\n", - "@partial(jax.pmap, axis_name=\"batch\")\n", - "def p_clip(inputs, params):\n", - " logits = clip(params=params, **inputs).logits_per_image\n", - " return logits" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FoLXpjCmGpju" - }, - "outputs": [], - "source": [ - "from flax.training.common_utils import shard\n", - "\n", - "# get clip scores\n", - "clip_inputs = clip_processor(\n", - " text=prompts * jax.device_count(),\n", - " images=images,\n", - " return_tensors=\"np\",\n", - " padding=\"max_length\",\n", - " max_length=77,\n", - " truncation=True,\n", - ").data\n", - "logits = p_clip(shard(clip_inputs), clip_params)\n", - "\n", - "# organize scores per prompt\n", - "p = len(prompts)\n", - "logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4AAWRm70LgED" - }, - "source": [ - "Let's now display images ranked by CLIP score." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zsgxxubLLkIu" - }, - "outputs": [], - "source": [ - "for i, prompt in enumerate(prompts):\n", - " print(f\"Prompt: {prompt}\\n\")\n", - " for idx in logits[i].argsort()[::-1]:\n", - " display(images[idx * p + i])\n", - " print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n", - " print()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oZT9i3jCjir0" - }, - "source": [ - "## 🪄 Optional: Save your Generated Images as W&B Tables\n", - "\n", - "W&B Tables is an interactive 2D grid with support to rich media logging. Use this to save the generated images on W&B dashboard and share with the world." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-pSiv6Vwjkn0" - }, - "outputs": [], - "source": [ - "import wandb\n", - "\n", - "# Initialize a W&B run.\n", - "project = \"dalle-mini-tables-colab\"\n", - "run = wandb.init(project=project)\n", - "\n", - "# Initialize an empty W&B Tables.\n", - "columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n", - "gen_table = wandb.Table(columns=columns)\n", - "\n", - "# Add data to the table.\n", - "for i, prompt in enumerate(prompts):\n", - " # If CLIP scores exist, sort the Images\n", - " if logits is not None:\n", - " idxs = logits[i].argsort()[::-1]\n", - " tmp_imgs = images[i :: len(prompts)]\n", - " tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n", - " else:\n", - " tmp_imgs = images[i :: len(prompts)]\n", - "\n", - " # Add the data to the table.\n", - " gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n", - "\n", - "# Log the Table to W&B dashboard.\n", - "wandb.log({\"Generated Images\": gen_table})\n", - "\n", - "# Close the W&B run.\n", - "run.finish()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ck2ZnHwVjnRd" - }, - "source": [ - "Click on the link above to check out your generated images." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "machine_shape": "hm", - "name": "DALL·E mini - Inference pipeline.ipynb", - "provenance": [], - "gpuType": "A100", - "include_colab_link": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "118UKH5bWCGa" + }, + "source": [ + "# DALL·E mini - Inference pipeline\n", + "\n", + "*Generate images from a text prompt*\n", + "\n", + "\n", + "\n", + "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n", + "\n", + "Just want to play? Use directly [the app](https://www.craiyon.com/).\n", + "\n", + "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dS8LbaonYm3a" + }, + "source": [ + "## 🛠️ Installation and set-up" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uzjAM2GBYpZX" + }, + "outputs": [], + "source": [ + "# Required only for colab environments + GPU\n", + "!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", + "\n", + "# Install required libraries\n", + "!pip install -q dalle-mini orbax==0.0.23\n", + "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ozHzTkyv8cqU" + }, + "source": [ + "We load required models:\n", + "* DALL·E mini for text to encoded images\n", + "* VQGAN for decoding images\n", + "* CLIP for scoring predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "K6CxW2o42f-w" + }, + "outputs": [], + "source": [ + "# Model references\n", + "\n", + "# dalle-mega\n", + "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n", + "DALLE_COMMIT_ID = None\n", + "\n", + "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n", + "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n", + "\n", + "# VQGAN model\n", + "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n", + "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yv-aR3t4Oe5v" + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "# check how many devices are available\n", + "jax.local_device_count()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "92zYmvsQ38vL" + }, + "outputs": [], + "source": [ + "# Load models & tokenizer\n", + "from dalle_mini import DalleBart, DalleBartProcessor\n", + "from vqgan_jax.modeling_flax_vqgan import VQModel\n", + "from transformers import CLIPProcessor, FlaxCLIPModel\n", + "\n", + "# Load dalle-mini\n", + "model, params = DalleBart.from_pretrained(\n", + " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n", + ")\n", + "\n", + "# Load VQGAN\n", + "vqgan, vqgan_params = VQModel.from_pretrained(\n", + " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o_vH2X1tDtzA" + }, + "source": [ + "Model parameters are replicated on each device for faster inference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wtvLoM48EeVw" + }, + "outputs": [], + "source": [ + "from flax.jax_utils import replicate\n", + "\n", + "params = replicate(params)\n", + "vqgan_params = replicate(vqgan_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0A9AHQIgZ_qw" + }, + "source": [ + "Model functions are compiled and parallelized to take advantage of multiple devices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sOtoOmYsSYPz" + }, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "\n", + "# model inference\n", + "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n", + "def p_generate(\n", + " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n", + "):\n", + " return model.generate(\n", + " **tokenized_prompt,\n", + " prng_key=key,\n", + " params=params,\n", + " top_k=top_k,\n", + " top_p=top_p,\n", + " temperature=temperature,\n", + " condition_scale=condition_scale,\n", + " )\n", + "\n", + "\n", + "# decode image\n", + "@partial(jax.pmap, axis_name=\"batch\")\n", + "def p_decode(indices, params):\n", + " return vqgan.decode_code(indices, params=params)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HmVN6IBwapBA" + }, + "source": [ + "Keys are passed to the model on each device to generate unique inference per device." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4CTXmlUkThhX" + }, + "outputs": [], + "source": [ + "import random\n", + "\n", + "# create a random key\n", + "seed = random.randint(0, 2**32 - 1)\n", + "key = jax.random.PRNGKey(seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BrnVyCo81pij" + }, + "source": [ + "## 🖍 Text Prompt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rsmj0Aj5OQox" + }, + "source": [ + "Our model requires processing prompts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YjjhUychOVxm" + }, + "outputs": [], + "source": [ + "from dalle_mini import DalleBartProcessor\n", + "\n", + "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BQ7fymSPyvF_" + }, + "source": [ + "Let's define some text prompts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x_0vI9ge1oKr" + }, + "outputs": [], + "source": [ + "prompts = [\n", + " \"sunset over a lake in the mountains\",\n", + " \"the Eiffel tower landing on the moon\",\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XlZUG3SCLnGE" + }, + "source": [ + "Note: we could use the same prompt multiple times for faster inference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VKjEZGjtO49k" + }, + "outputs": [], + "source": [ + "tokenized_prompts = processor(prompts)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-CEJBnuJOe5z" + }, + "source": [ + "Finally we replicate the prompts onto each device." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lQePgju5Oe5z" + }, + "outputs": [], + "source": [ + "tokenized_prompt = replicate(tokenized_prompts)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "phQ9bhjRkgAZ" + }, + "source": [ + "## 🎨 Generate images\n", + "\n", + "We generate images using dalle-mini model and decode them with the VQGAN." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d0wVkXpKqnHA" + }, + "outputs": [], + "source": [ + "# number of predictions per prompt\n", + "n_predictions = 8\n", + "\n", + "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n", + "gen_top_k = None\n", + "gen_top_p = None\n", + "temperature = None\n", + "cond_scale = 10.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SDjEx9JxR3v8" + }, + "outputs": [], + "source": [ + "from flax.training.common_utils import shard_prng_key\n", + "import numpy as np\n", + "from PIL import Image\n", + "from tqdm.notebook import trange\n", + "\n", + "print(f\"Prompts: {prompts}\\n\")\n", + "# generate images\n", + "images = []\n", + "for i in trange(max(n_predictions // jax.device_count(), 1)):\n", + " # get a new key\n", + " key, subkey = jax.random.split(key)\n", + " # generate images\n", + " encoded_images = p_generate(\n", + " tokenized_prompt,\n", + " shard_prng_key(subkey),\n", + " params,\n", + " gen_top_k,\n", + " gen_top_p,\n", + " temperature,\n", + " cond_scale,\n", + " )\n", + " # remove BOS\n", + " encoded_images = encoded_images.sequences[..., 1:]\n", + " # decode images\n", + " decoded_images = p_decode(encoded_images, vqgan_params)\n", + " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n", + " for decoded_img in decoded_images:\n", + " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n", + " images.append(img)\n", + " display(img)\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tw02wG9zGmyB" + }, + "source": [ + "## 🏅 Optional: Rank images by CLIP score\n", + "\n", + "We can rank images according to CLIP.\n", + "\n", + "**Note: your session may crash if you don't have a subscription to Colab Pro.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RGjlIW_f6GA0" + }, + "outputs": [], + "source": [ + "# CLIP model\n", + "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n", + "CLIP_COMMIT_ID = None\n", + "\n", + "# Load CLIP\n", + "clip, clip_params = FlaxCLIPModel.from_pretrained(\n", + " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n", + ")\n", + "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n", + "clip_params = replicate(clip_params)\n", + "\n", + "\n", + "# score images\n", + "@partial(jax.pmap, axis_name=\"batch\")\n", + "def p_clip(inputs, params):\n", + " logits = clip(params=params, **inputs).logits_per_image\n", + " return logits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FoLXpjCmGpju" + }, + "outputs": [], + "source": [ + "from flax.training.common_utils import shard\n", + "\n", + "# get clip scores\n", + "clip_inputs = clip_processor(\n", + " text=prompts * jax.device_count(),\n", + " images=images,\n", + " return_tensors=\"np\",\n", + " padding=\"max_length\",\n", + " max_length=77,\n", + " truncation=True,\n", + ").data\n", + "logits = p_clip(shard(clip_inputs), clip_params)\n", + "\n", + "# organize scores per prompt\n", + "p = len(prompts)\n", + "logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4AAWRm70LgED" + }, + "source": [ + "Let's now display images ranked by CLIP score." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zsgxxubLLkIu" + }, + "outputs": [], + "source": [ + "for i, prompt in enumerate(prompts):\n", + " print(f\"Prompt: {prompt}\\n\")\n", + " for idx in logits[i].argsort()[::-1]:\n", + " display(images[idx * p + i])\n", + " print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oZT9i3jCjir0" + }, + "source": [ + "## 🪄 Optional: Save your Generated Images as W&B Tables\n", + "\n", + "W&B Tables is an interactive 2D grid with support to rich media logging. Use this to save the generated images on W&B dashboard and share with the world." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-pSiv6Vwjkn0" + }, + "outputs": [], + "source": [ + "import wandb\n", + "\n", + "# Initialize a W&B run.\n", + "project = \"dalle-mini-tables-colab\"\n", + "run = wandb.init(project=project)\n", + "\n", + "# Initialize an empty W&B Tables.\n", + "columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n", + "gen_table = wandb.Table(columns=columns)\n", + "\n", + "# Add data to the table.\n", + "for i, prompt in enumerate(prompts):\n", + " # If CLIP scores exist, sort the Images\n", + " if logits is not None:\n", + " idxs = logits[i].argsort()[::-1]\n", + " tmp_imgs = images[i :: len(prompts)]\n", + " tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n", + " else:\n", + " tmp_imgs = images[i :: len(prompts)]\n", + "\n", + " # Add the data to the table.\n", + " gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n", + "\n", + "# Log the Table to W&B dashboard.\n", + "wandb.log({\"Generated Images\": gen_table})\n", + "\n", + "# Close the W&B run.\n", + "run.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ck2ZnHwVjnRd" + }, + "source": [ + "Click on the link above to check out your generated images." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "machine_shape": "hm", + "name": "DALL·E mini - Inference pipeline.ipynb", + "provenance": [], + "gpuType": "A100", + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } \ No newline at end of file diff --git a/tools/train/train.py b/tools/train/train.py index e8d97db7a..7a036546b 100644 --- a/tools/train/train.py +++ b/tools/train/train.py @@ -1179,7 +1179,6 @@ def loss_fn(logits, labels): # Define gradient update step fn def train_step(state, batch, train_time): - # get a minibatch (one gradient accumulation slice) def get_minibatch(batch, grad_idx): return jax.tree_util.tree_map( @@ -1539,7 +1538,6 @@ def run_evaluation(): def run_save_model(state, eval_metrics=None): if jax.process_index() == 0: - start_save_time = time.perf_counter() output_dir = training_args.output_dir use_bucket = output_dir.startswith("gs://")