Support offloading encode, for generate() with much less VRAM #269

Open · drdaxxy wants to merge 1 commit into main

Conversation


@drdaxxy commented Jun 19, 2022

Transformers' generate() can take precomputed encoder outputs as kwargs instead of running the encoder itself. This PR extends that support to "super conditioning" sampling. It also allows providing only one "null sequence" per batch, as input tokens or encoder state, since that unconditional prompt is normally constant.

How is this useful? We only need to run the encoder once per distinct prompt, which even on a household CPU takes 1-2 seconds for a single input (worst case, no batching, no reuse). With that step offloaded, generate() works without 2 or 4 gigabytes of encoder weights (mega-1-fp16 and mega-1, respectively) hogging VRAM.

That way, mega-1-fp16 can run on a 4GB GPU (1-batches, without VQGAN, which is fast enough on CPU) and full-precision mega-1 can run on an 8GB GPU (1-batches with VQGAN, up to 3-batches without).

Specifically, without VQGAN, 1-batches need 3728 MiB in float16 and 6770 MiB in float32 this way. GPU-accelerating VQGAN adds 770 MiB, assuming we also del vqgan_params["encoder"] (the VQGAN encoder is never needed for generating images) before replicate(vqgan_params) or the like.
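
For reference, that VQGAN part might look roughly like this (the repo name and loading flags are taken from the usual dalle-mini inference notebook and are assumptions about your setup, not part of this PR):

```python
from flax.jax_utils import replicate
from vqgan_jax.modeling_flax_vqgan import VQModel

# Load the VQGAN used by dalle-mini; _do_init=False keeps the weights on the host.
vqgan, vqgan_params = VQModel.from_pretrained(
    "dalle-mini/vqgan_imagenet_f16_16384", _do_init=False
)

# Generating images only ever uses the VQGAN decoder, so drop the encoder
# weights from the parameter tree before replicating it onto the GPU(s).
del vqgan_params["encoder"]
vqgan_params = replicate(vqgan_params)
```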

On systems that have enough memory anyway, up to 10 (fp32) or 20 (fp16) more items fit in a batch. Given the CPU encode cost, that works out a few percent slower or faster in my experience (especially when combined with the other tricks in #247), depending on how much state is shared.
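
In case a rough usage sketch helps: something like the following should work, assuming the kwarg for precomputed encoder state is named encoder_outputs as in upstream Transformers (the checkpoint reference and loading code follow the standard inference notebook; adjust for your setup):

```python
import jax
import jax.numpy as jnp
from dalle_mini import DalleBart, DalleBartProcessor

DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # as in the inference notebook

# _do_init=False keeps the weights as host arrays until we decide where they go.
model, params = DalleBart.from_pretrained(DALLE_MODEL, dtype=jnp.float16, _do_init=False)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL)

tokenized = processor(["a watercolor painting of a fox"])

# 1) Run the encoder once, pinned to the CPU backend, so its ~2-4 GB of weights
#    never need to sit in VRAM.
encode_on_cpu = jax.jit(model.encode, backend="cpu")
encoder_outputs = encode_on_cpu(
    input_ids=tokenized["input_ids"],
    attention_mask=tokenized["attention_mask"],
    params=params,
)

# 2) Sample image tokens on the GPU from the precomputed encoder state instead
#    of input_ids. "encoder_outputs" mirrors the upstream Transformers kwarg;
#    the exact name accepted after this PR may differ. In practice you would
#    move only the non-encoder weights to the device here.
encoded_images = model.generate(
    encoder_outputs=encoder_outputs,
    attention_mask=tokenized["attention_mask"],
    prng_key=jax.random.PRNGKey(0),
    params=params,
    condition_scale=10.0,
)
# encoded_images.sequences holds VQGAN token ids; decode them to pixels afterwards.
```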

@Kepler-Br

Sounds awesome!
But as someone who just wants to try it out, I can't quickly figure out how to use the offloading.
Could you please add a usage example?


@drdaxxy (Author) commented Jun 19, 2022

Could you please add a usage example?

I don't have time to write a proper example now, sorry... I'm hoping another developer decides to take care of that.

@TakuSmash

Could this even get the full one working on a GPU with much less VRAM too? The full mega checkpoint instead of just the fp16 one?


@Kepler-Br commented Jun 21, 2022

Could this even get the full one working on a GPU with much less VRAM too? The full mega checkpoint instead of just the fp16 one?

I guess so:

full-precision mega-1 can run on an 8GB GPU (1-batches with VQGAN, up to 3-batches without).


@TakuSmash commented Jun 21, 2022

Could this even get the full one working on a GPU with much less VRAM too? The full mega checkpoint instead of just the fp16 one?

I guess so:

full-precision mega-1 can run on an 8GB GPU (1-batches with VQGAN, up to 3-batches without).

Wait, so my RTX 3060 should already be good to go for running this in something like Visions of Chaos? The full checkpoint?

@borisdayma (Owner)

Those are very interesting ideas @drdaxxy!

I'm gonna try to think about how to integrate it in a clean way.
