Skip to content

Commit

Permalink
re-use already downloaded input images for serverless (thanks Calvin!)
Browse files Browse the repository at this point in the history
  • Loading branch information
robballantyne committed Dec 23, 2023
1 parent 9147eb5 commit 218c035
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 13 deletions.
2 changes: 2 additions & 0 deletions build/COPY_ROOT/opt/ai-dock/bin/build/layer0/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ build_common_create_env() {

# RunPod serverless support
$MAMBA_CREATE -n serverless python=3.10
$MAMBA_INSTALL -n serverless \
python-magic
micromamba run -n serverless $PIP_INSTALL \
runpod
}
Expand Down
18 changes: 13 additions & 5 deletions build/COPY_ROOT/opt/serverless/handlers/basehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import shutil
from utils.s3utils import s3utils
from utils.network import Network
from utils.filesystem import Filesystem

class BaseHandler:
ENDPOINT_PROMPT="http://127.0.0.1:18188/prompt"
Expand Down Expand Up @@ -66,12 +67,19 @@ def replace_urls(self, data):
return data

def get_url_content(self, url):
return os.path.basename(Network.download_file(
url,
self.get_input_dir(),
self.request_id
existing_file = Filesystem.find_input_file(
self.get_input_dir(),
Network.get_url_hash(url)
)
if existing_file:
return os.path.basename(existing_file)
else:
return os.path.basename(Network.download_file(
url,
self.get_input_dir(),
self.request_id
)
)
)

def is_server_ready(self):
try:
Expand Down
24 changes: 24 additions & 0 deletions build/COPY_ROOT/opt/serverless/utils/filesystem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import glob
import magic
import mimetypes

class Filesystem:
def __init__(self):
pass

@staticmethod
def find_input_file(directory, str_hash):
# Hashed url should have only one result
try:
matched = glob.glob(f'{directory}/{str_hash}*')
if len(matched) > 0:
return matched[0]
return None
except:
return None

@staticmethod
def get_file_extension(filepath):
mime_str = magic.from_file(filepath, mime=True)
return mimetypes.guess_extension(mime_str)

24 changes: 16 additions & 8 deletions build/COPY_ROOT/opt/serverless/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import requests
import os
import uuid
import hashlib
from .filesystem import Filesystem

class Network:
def __init__(self):
Expand All @@ -14,23 +16,29 @@ def is_url(value):
except:
return False

@staticmethod
def get_url_hash(url):
return hashlib.md5((f'{url}').encode()).hexdigest()

# todo - threads
@staticmethod
def download_file(url, target_dir, request_id):
try:
file_name_hash = Network.get_url_hash(url)
os.makedirs(target_dir, exist_ok=True)
response = requests.get(url, timeout=5)
if response.status_code > 399:
raise requests.RequestException(f"Unable to download {url}")
if "content-disposition" in response.headers:
content_disposition = response.headers["content-disposition"]
filename = content_disposition.split("filename=")[1]
else:
filename = url.split("/")[-1]

filepath = f"{target_dir}/{request_id}-{uuid.uuid4()}-{filename}"
with open(filepath, mode="wb") as file:

filepath_hash = f"{target_dir}/{file_name_hash}"
# ignore above
with open(filepath_hash, mode="wb") as file:
file.write(response.content)

file_extension = Filesystem.get_file_extension(filepath_hash)
filepath = f"{filepath_hash}{file_extension}"
os.replace(filepath_hash, filepath)

except:
raise

Expand Down

0 comments on commit 218c035

Please sign in to comment.