Skip to content

Commit

Permalink
remove download
Browse files Browse the repository at this point in the history
  • Loading branch information
maxin-cn committed Jun 4, 2024
1 parent 0bcb798 commit e8409e3
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 26 deletions.
22 changes: 0 additions & 22 deletions download.py

This file was deleted.

4 changes: 2 additions & 2 deletions sample/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
import utils

from diffusion import create_diffusion
from download import find_model
from utils import find_model
except:
sys.path.append(os.path.split(sys.path[0])[0])

import utils

from diffusion import create_diffusion
from download import find_model
from utils import find_model

import torch
import argparse
Expand Down
2 changes: 1 addition & 1 deletion sample/sample_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
sys.path.append(os.path.split(sys.path[0])[0])
import torch.distributed as dist
from download import find_model
from utils import find_model
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from tqdm import tqdm
Expand Down
1 change: 0 additions & 1 deletion sample/sample_t2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import os, sys
sys.path.append(os.path.split(sys.path[0])[0])
from download import find_model
from pipeline_videogen import VideoGenPipeline
from models import get_models
from utils import save_video_grid
Expand Down
14 changes: 14 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,20 @@ def save_video_grid(video, nrow=None):

return video_grid

def find_model(model_name):
"""
Finds a pre-trained Latte model, downloading it if necessary. Alternatively, loads a model from a local path.
"""
assert os.path.isfile(model_name), f'Could not find Latte checkpoint at {model_name}'
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)

if "ema" in checkpoint: # supports checkpoints from train.py
print('Using Ema!')
checkpoint = checkpoint["ema"]
else:
print('Using model!')
checkpoint = checkpoint['model']
return checkpoint

#################################################################################
# MMCV Utils #
Expand Down

0 comments on commit e8409e3

Please sign in to comment.