
Token Merging for Stable Diffusion running with OpenVINO

This is an OpenVINO-adapted version of the Token Merging (ToMe) method. The method is applied to a PyTorch model before the model is exported to the OpenVINO representation. It can also be combined with 8-bit quantization to achieve higher inference speed. The repository contains implementations for:

  • Stable Diffusion (Hugging Face Diffusers-based models), see example.
  • OpenCLIP, see example.
  • Timm

Results for 100 iterations of 512x512 image generation on CPU:

[Figure: ToMe for SD applied on a 512x512 image.]

This is the official implementation of ToMe for SD from the paper:
Token Merging for Fast Stable Diffusion

ToMe for SD is an extension of the original ToMe:
Token Merging: Your ViT but Faster
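For orientation, here is a minimal pure-Python sketch of the bipartite soft matching step described in the original ToMe paper: tokens are split alternately into two sets A and B, each A token is paired with its most similar B token by cosine similarity, and the r most similar pairs are merged by averaging. This is a simplified illustration, not the tomeov code: the helper names are hypothetical, and it omits details of the real method such as computing similarity on attention keys, protecting the class token, size-weighted averaging, and the 2D destination selection used by ToMe for SD.

```python
import math
from typing import List

Token = List[float]


def cosine_similarity(a: Token, b: Token) -> float:
    """Cosine similarity between two feature vectors (0.0 if either is all zeros)."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def bipartite_soft_matching(tokens: List[Token], r: int) -> List[Token]:
    """Hypothetical helper: merge up to r tokens, returning len(tokens) - r tokens."""
    # Alternate split: even positions form set A (sources), odd positions set B (destinations).
    a_idx = list(range(0, len(tokens), 2))
    b_idx = list(range(1, len(tokens), 2))
    # Only A tokens can be merged away, so r is capped at |A|.
    r = min(r, len(a_idx))
    if r <= 0 or not b_idx:
        return [list(t) for t in tokens]

    # For each A token, find its most similar B token (one edge per A token).
    edges = []
    for i in a_idx:
        score, j = max((cosine_similarity(tokens[i], tokens[j]), j) for j in b_idx)
        edges.append((score, i, j))

    # Keep only the r most similar edges; those A tokens are merged into their B match.
    edges.sort(key=lambda e: e[0], reverse=True)
    merged_edges = edges[:r]
    merged_src = {i for _, i, _ in merged_edges}

    # Average each B token with the A tokens merged into it (a plain mean,
    # similar in spirit to torch scatter_reduce with reduce="mean").
    sums = {j: list(tokens[j]) for j in b_idx}
    counts = {j: 1 for j in b_idx}
    for _, i, j in merged_edges:
        sums[j] = [s + x for s, x in zip(sums[j], tokens[i])]
        counts[j] += 1

    unmerged_a = [list(tokens[i]) for i in a_idx if i not in merged_src]
    merged_b = [[s / counts[j] for s in sums[j]] for j in b_idx]
    return unmerged_a + merged_b


if __name__ == "__main__":
    toks = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.0, 1.1]]
    print(bipartite_soft_matching(toks, r=1))
```

In the usage example, the third token is nearly parallel to the fourth, so with r=1 those two are averaged and the sequence shrinks from four tokens to three. Merging (rather than dropping) tokens keeps their information in the averaged representation.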

Note: This also supports most downstream UIs that use these repositories.

Installation

ToMe for SD requires PyTorch >= 1.12.1 (for scatter_reduce). After installing PyTorch and your Stable Diffusion environment of choice, use the corresponding Python environment to install ToMe for SD:

pip install "git+https://github.com/openvinotoolkit/openvino_contrib.git#egg=tomeov&subdirectory=modules/token_merging"

Usage

  • Diffusers:
import torch, tomeov
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

save_dir = "stable_diffusion_optimized"
# Apply ToMe with a 30% merging ratio
tomeov.patch_stable_diffusion(pipe, ratio=0.3) # Can also use pipe.unet in place of pipe here
  • OpenCLIP:
import torch, tomeov
import open_clip
from open_clip import tokenizer

model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-16-plus-240", pretrained="laion400m_e32")

tomeov.patch_openclip(model, 8) # 8 - number of tokens merged in each MHSA from top down
  • Timm:
import torch, tomeov
import timm

model_name = 'vit_tiny_patch16_224'
model = timm.create_model(model_name, pretrained=True)

tomeov.patch_timm(model, 4) # 4 - number of tokens merged in each MHSA from top down
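The integer passed to patch_openclip and patch_timm (tokens merged per attention block) and the ratio passed to patch_stable_diffusion translate into token counts roughly as follows. This is back-of-the-envelope arithmetic, not part of tomeov: the helper names are hypothetical, and it assumes tomeov keeps the original ToMe behavior (r tokens merged once per block, capped at half of the unprotected tokens, with the class token protected) and the original ToMe for SD behavior (merging only in the highest-resolution UNet attention layers, r = int(tokens * ratio), one destination token kept per 2x2 latent patch).

```python
from typing import List, Tuple


def vit_token_counts(num_tokens: int, r: int, depth: int, protected: int = 1) -> List[int]:
    """Hypothetical helper: token count before the first block and after each block.

    Assumes r tokens are merged once per transformer block and, as in the original
    ToMe, that r is capped at (tokens - protected) // 2 (the class token is protected).
    """
    counts = [num_tokens]
    for _ in range(depth):
        r_eff = max(min(r, (counts[-1] - protected) // 2), 0)
        counts.append(counts[-1] - r_eff)
    return counts


def sd_tokens_after_merge(height: int, width: int, ratio: float,
                          vae_factor: int = 8) -> Tuple[int, int, int]:
    """Hypothetical helper: (tokens, merged, remaining) in the highest-resolution UNet layers.

    Assumes the original ToMe for SD rule r = int(tokens * ratio), with one destination
    token kept per 2x2 latent patch, so at most 3/4 of the tokens can be merged.
    """
    lat_h, lat_w = height // vae_factor, width // vae_factor
    tokens = lat_h * lat_w                      # one token per latent pixel
    num_dst = (lat_h // 2) * (lat_w // 2)       # destinations that are never merged away
    merged = min(int(tokens * ratio), tokens - num_dst)
    return tokens, merged, tokens - merged


if __name__ == "__main__":
    # vit_tiny_patch16_224: (224 / 16)^2 = 196 patches + 1 class token, 12 blocks, r = 4
    print(vit_token_counts(197, r=4, depth=12)[-1])  # 149
    # Stable Diffusion v1.5 at 512x512 (64x64 latent) with ratio=0.3
    print(sd_tokens_after_merge(512, 512, 0.3))  # (4096, 1228, 2868)
```

Under these assumptions, patch_timm(model, 4) on vit_tiny_patch16_224 shrinks the sequence from 197 to 149 tokens by the end of the last block, and ratio=0.3 at 512x512 merges 1228 of the 4096 tokens in the highest-resolution attention layers.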