-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimage_processing.py
113 lines (92 loc) · 4.09 KB
/
image_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import base64
import tempfile
from io import BytesIO
from rembg.cli import remove
from PIL import Image
from flask import Blueprint, request, abort, Response
from flask_login import login_required, current_user
from app import db, app
from helpers import return_as_json, return_as_json_list, list_to_dict
from image_processing_schemas import extract_faces_schema, remove_background_schema
from orm import UserImage
from flask_expects_json import expects_json
from test_enhance_single_unalign import *
# Blueprint that groups the image-processing endpoints defined in this module.
image_processing = Blueprint('image_processing', __name__)

# The face-enhancement assets below are loaded once, at import time, and are
# shared by every request handled by this module.
print("Face Req Loading... ")
# Detector handed to detect_and_align_faces. Previously the CNN detector was
# loaded from './pretrain_models/mmod_human_face_detector.dat' first and the
# name was then immediately rebound to this frontal-face detector, so the CNN
# model was loaded only to be discarded. Only the detector that was actually
# in effect is kept.
face_detector = dlib.get_frontal_face_detector()
lmk_predictor = dlib.shape_predictor('./pretrain_models/shape_predictor_5_face_landmarks.dat')
template_path = './pretrain_models/FFHQ_template.npy'
opt = TestOptions().parse()
# NOTE(review): presumably selects CPU execution (zero GPUs) -- confirm against
# TestOptions, and confirm that setting it after parse() still takes effect.
opt.gpus = 0
enhance_model = def_models(opt)
print("Face Req Loaded... ")
@image_processing.route('/remove_background/<image_id>', methods=['GET'])
@login_required
def remove_background(image_id):
    """Store a background-removed copy of one of the current user's images.

    The image is looked up by ``image_id`` and must belong to the logged-in
    user. Its base64 payload is decoded, passed through rembg's ``remove`` and
    re-encoded, and the result is saved as a new ``UserImage`` whose parent is
    the source image.

    NOTE(review): this endpoint writes to the database although it is
    registered for GET only; the method is left unchanged for existing clients.

    Args:
        image_id: Identifier of the source image, taken from the URL.

    Returns:
        The JSON response built by ``return_as_json`` from the new image's
        ``to_dict()``.

    Raises:
        HTTPException: Aborts with a 400 "No such image" response when the
            current user has no image with this id.
    """
    user_image = UserImage.query.filter_by(id=image_id, user_id=current_user.id).first()
    if user_image is None:
        # abort() takes the response object itself. Passed as a second
        # argument it becomes the error *description*, and the client would
        # receive the Response's repr instead of the message.
        abort(Response("No such image", status=400))

    # A root_id of 0 means the source image is itself the root of its chain.
    root_id = user_image.id if user_image.root_id == 0 else user_image.root_id

    img_bytes = base64.b64decode(user_image.data)
    img_bytes_out = remove(img_bytes)
    image_data = base64.b64encode(img_bytes_out)

    new_image = UserImage(parent_id=user_image.id, user_id=current_user.id, data=image_data,
                          root_id=root_id)
    db.session.add(new_image)
    # commit() already flushes pending changes, so no separate flush() follows.
    db.session.commit()
    return return_as_json(new_image.to_dict())
@image_processing.route('/extract_faces/<image_id>', methods=['GET'])
@login_required
def extract_faces(image_id):
    """Detect, enhance and store the faces found in one of the user's images.

    The source image is loaded, faces are detected and aligned, and each
    aligned face is enhanced. Every enhanced face is stored as a new
    ``UserImage``; the enhanced faces are also pasted back into the full
    picture, which is stored as one more ``UserImage``. All new rows are
    committed together.

    NOTE(review): this endpoint writes to the database although it is
    registered for GET only; the method is left unchanged for existing clients.

    Args:
        image_id: Identifier of the source image, taken from the URL.

    Returns:
        A JSON response with ``faces`` (the stored face images) and
        ``enhanced`` (the stored full picture, or ``None`` when no face was
        found), or a 500 "Check the logs" response if building the new images
        fails.

    Raises:
        HTTPException: Aborts with a 400 "No such image" response when the
            current user has no image with this id.
    """
    current_user_id = current_user.id
    user_image = UserImage.query.filter_by(id=image_id, user_id=current_user_id).first()
    if user_image is None:
        # abort() takes the response object itself. Passed as a second
        # argument it becomes the error *description*, and the client would
        # receive the Response's repr instead of the message.
        abort(Response("No such image", status=400))

    user_image_id = user_image.id
    # A root_id of 0 means the source image is itself the root of its chain.
    root_id = user_image_id if user_image.root_id == 0 else user_image.root_id

    img = get_image_as_dlib(user_image)
    aligned_faces, tform_params = detect_and_align_faces(img, face_detector, lmk_predictor, template_path)
    if len(aligned_faces) == 0:
        return return_as_json({'faces': list_to_dict([]), 'enhanced': None})

    # The second return value (parse maps) is not used by this endpoint.
    hq_faces, _ = enhance_faces(aligned_faces, enhance_model)

    hq_images = []
    try:
        for hq_face in hq_faces:
            face_image = create_image_from_array(current_user_id, hq_face, root_id, user_image_id)
            db.session.add(face_image)
            hq_images.append(face_image)

        enhanced_array = past_faces_back(img, hq_faces, tform_params, upscale=opt.test_upscale)
        enhanced_image = create_image_from_array(current_user_id, enhanced_array, root_id, user_image_id)
        db.session.add(enhanced_image)
    except Exception:
        # Exception rather than BaseException, so KeyboardInterrupt and
        # SystemExit still propagate. Discard the rows already added to the
        # session so a partial result is never committed later.
        db.session.rollback()
        app.logger.exception("extract_faces failed for image %s", image_id)
        return Response("Check the logs", status=500)

    # commit() already flushes pending changes, so no separate flush() follows.
    db.session.commit()
    return return_as_json({'faces': list_to_dict(hq_images), 'enhanced': enhanced_image.to_dict()})
def get_image_as_dlib(user_image):
    """Decode a stored image and load it through ``dlib.load_rgb_image``.

    The base64 payload in ``user_image.data`` is decoded and written to a
    temporary file, because the image is loaded by file name. The temporary
    file is always removed, including when loading fails.

    Args:
        user_image: Object whose ``data`` attribute holds the base64-encoded
            image, either as ``str`` or as ``bytes``.

    Returns:
        Whatever ``dlib.load_rgb_image`` returns for the decoded file.
    """
    # Imported here because the module has no explicit ``import os``; the name
    # was previously reachable only through a star import.
    import os

    encoded_image = user_image.data
    if isinstance(encoded_image, str):
        encoded_image = encoded_image.encode('utf-8')
    # Decode before creating the file so a bad payload leaves nothing on disk.
    img_bytes = base64.b64decode(encoded_image)

    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(img_bytes)
    # The file is closed (and therefore flushed) before it is read by name.
    try:
        return dlib.load_rgb_image(tmp.name)
    finally:
        os.remove(tmp.name)
def create_image_from_array(current_user_id, hq_img, root_id, user_image_id):
    """Build a ``UserImage`` holding ``hq_img`` as base64-encoded PNG data.

    The returned object is only constructed here; it is not added to the
    database session.

    Args:
        current_user_id: Value stored as the new image's ``user_id``.
        hq_img: Array passed to ``PIL.Image.fromarray``.
        root_id: Value stored as the new image's ``root_id``.
        user_image_id: Value stored as the new image's ``parent_id``.

    Returns:
        The new ``UserImage``; its ``data`` is the ``bytes`` object returned
        by ``base64.b64encode``.
    """
    png_buffer = BytesIO()
    Image.fromarray(hq_img).save(png_buffer, format="PNG")
    encoded_png = base64.b64encode(png_buffer.getvalue())
    return UserImage(
        parent_id=user_image_id,
        user_id=current_user_id,
        data=encoded_png,
        root_id=root_id,
    )