diff --git a/download.py b/download.py
deleted file mode 100644
index 3ed1239..0000000
--- a/download.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import os
-
-def find_model(model_name):
-    """
-    Finds a pre-trained Latte model, downloading it if necessary. Alternatively, loads a model from a local path.
-    """
-    assert os.path.isfile(model_name), f'Could not find Latte checkpoint at {model_name}'
-    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
-
-    if "ema" in checkpoint:  # supports checkpoints from train.py
-        print('Using Ema!')
-        checkpoint = checkpoint["ema"]
-    else:
-        print('Using model!')
-        checkpoint = checkpoint['model']
-    return checkpoint
\ No newline at end of file
diff --git a/sample/sample.py b/sample/sample.py
index 9fd5991..e99e776 100644
--- a/sample/sample.py
+++ b/sample/sample.py
@@ -12,14 +12,14 @@
     import utils
 
     from diffusion import create_diffusion
-    from download import find_model
+    from utils import find_model
 except:
     sys.path.append(os.path.split(sys.path[0])[0])
 
     import utils
 
     from diffusion import create_diffusion
-    from download import find_model
+    from utils import find_model
 
 import torch
 import argparse
diff --git a/sample/sample_ddp.py b/sample/sample_ddp.py
index fcc073a..c9ca3ce 100644
--- a/sample/sample_ddp.py
+++ b/sample/sample_ddp.py
@@ -16,7 +16,7 @@
 import torch
 sys.path.append(os.path.split(sys.path[0])[0])
 import torch.distributed as dist
-from download import find_model
+from utils import find_model
 from diffusion import create_diffusion
 from diffusers.models import AutoencoderKL
 from tqdm import tqdm
diff --git a/sample/sample_t2x.py b/sample/sample_t2x.py
index 2b86593..87ffa65 100644
--- a/sample/sample_t2x.py
+++ b/sample/sample_t2x.py
@@ -15,7 +15,6 @@
 import os, sys
 sys.path.append(os.path.split(sys.path[0])[0])
-from download import find_model
 from pipeline_videogen import VideoGenPipeline
 from models import get_models
 from utils import save_video_grid
 
diff --git a/utils.py b/utils.py
index ad8b831..42c3356 100644
--- a/utils.py
+++ b/utils.py
@@ -271,6 +271,20 @@ def save_video_grid(video, nrow=None):
 
     return video_grid
 
+def find_model(model_name):
+    """
+    Finds a pre-trained Latte model, downloading it if necessary. Alternatively, loads a model from a local path.
+    """
+    assert os.path.isfile(model_name), f'Could not find Latte checkpoint at {model_name}'
+    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
+
+    if "ema" in checkpoint:  # supports checkpoints from train.py
+        print('Using Ema!')
+        checkpoint = checkpoint["ema"]
+    else:
+        print('Using model!')
+        checkpoint = checkpoint['model']
+    return checkpoint
 
 #################################################################################
 #                                  MMCV Utils                                   #
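
After this change, find_model is defined in utils.py and the download module no longer exists, so callers use "from utils import find_model". A minimal usage sketch of the new import path (the checkpoint path is a placeholder, not part of this diff):

    from utils import find_model  # previously: from download import find_model

    # find_model asserts that the path exists locally, then returns the EMA weights
    # if the checkpoint contains them, otherwise the raw model weights.
    state_dict = find_model('/path/to/latte_checkpoint.pt')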