Improving performance of video inference and increasing GPU utilization #178

Open · wants to merge 8 commits into master
inference_video.py: 255 changes (154 additions & 101 deletions)
@@ -30,6 +30,8 @@
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
+from multiprocessing import Process, Pipe
+from queue import Queue
from threading import Thread
from tqdm import tqdm
from PIL import Image
@@ -79,15 +81,36 @@

class VideoWriter:
    def __init__(self, path, frame_rate, width, height):
-        self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
+        output_p, input_p = Pipe()
+        self.worker = Process(target=self.VideoWriterWorker, args=(path, frame_rate, width, height, (output_p, input_p)))
+        self.worker.start()
+        output_p.close()
+        self.input_p = input_p

    def add_batch(self, frames):
-        frames = frames.mul(255).byte()
-        frames = frames.cpu().permute(0, 2, 3, 1).numpy()
-        for i in range(frames.shape[0]):
-            frame = frames[i]
-            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-            self.out.write(frame)
+        frames = frames.mul(255).byte().permute(0, 2, 3, 1)
+        self.input_p.send(frames.cpu())
+
+    def close(self):
+        # An integer acts as the shutdown sentinel for the worker process
+        self.input_p.send(0)
+        self.worker.join()
+
+    @staticmethod
+    def VideoWriterWorker(path, frame_rate, width, height, pipe):
+        output_p, input_p = pipe
+        input_p.close()
+        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
+        while True:
+            read_buffer = output_p.recv()
+            # Exit gracefully once the integer sentinel is received
+            if isinstance(read_buffer, int):
+                break
+            frames = read_buffer.numpy()
+            for i in range(frames.shape[0]):
+                frame = frames[i]
+                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+                out.write(frame)
+        out.release()
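
A note on the Pipe() handshake above: each process closes the pipe end it does not use, so the channel has exactly one sender (the main process) and one receiver (the worker), and recv() can raise EOFError instead of blocking forever if the sending side goes away. A minimal, self-contained sketch of the same pattern, with a string standing in for the frame tensors; the names here are illustrative, not part of the PR:

from multiprocessing import Process, Pipe

def worker(pipe):
    output_p, input_p = pipe
    input_p.close()                # child keeps only the receiving end
    while True:
        msg = output_p.recv()
        if isinstance(msg, int):   # integer sentinel, as in VideoWriter.close()
            break
        print('received:', msg)

if __name__ == '__main__':
    output_p, input_p = Pipe()
    p = Process(target=worker, args=((output_p, input_p),))
    p.start()
    output_p.close()               # parent keeps only the sending end
    input_p.send('a frame batch')  # stand-in for the CPU tensor sent by add_batch
    input_p.send(0)                # sentinel: tells the worker to finish
    p.join()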


class ImageSequenceWriter:
@@ -110,106 +133,136 @@ def _add_batch(self, frames, index):


# --------------- Main ---------------
-device = torch.device(args.device)
-
-
-# Load model
-if args.model_type == 'mattingbase':
-    model = MattingBase(args.model_backbone)
-if args.model_type == 'mattingrefine':
-    model = MattingRefine(
-        args.model_backbone,
-        args.model_backbone_scale,
-        args.model_refine_mode,
-        args.model_refine_sample_pixels,
-        args.model_refine_threshold,
-        args.model_refine_kernel_size)
-
-model = model.to(device).eval()
-model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
-
-
-# Load video and background
-vid = VideoDataset(args.video_src)
-bgr = [Image.open(args.video_bgr).convert('RGB')]
-dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
-    A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
-    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
-    A.PairApply(T.ToTensor())
-]))
-if args.video_target_bgr:
-    dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
-
-# Create output directory
-if os.path.exists(args.output_dir):
-    if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
-        shutil.rmtree(args.output_dir)
-    else:
-        exit()
-os.makedirs(args.output_dir)
-
-
-# Prepare writers
-if args.output_format == 'video':
-    h = args.video_resize[1] if args.video_resize is not None else vid.height
-    w = args.video_resize[0] if args.video_resize is not None else vid.width
-    if 'com' in args.output_types:
-        com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
-    if 'pha' in args.output_types:
-        pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
-    if 'fgr' in args.output_types:
-        fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
-    if 'err' in args.output_types:
-        err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
-    if 'ref' in args.output_types:
-        ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
-else:
-    if 'com' in args.output_types:
-        com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
-    if 'pha' in args.output_types:
-        pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
-    if 'fgr' in args.output_types:
-        fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
-    if 'err' in args.output_types:
-        err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
-    if 'ref' in args.output_types:
-        ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
-
-
-# Conversion loop
-with torch.no_grad():
-    for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
-        if args.video_target_bgr:
-            (src, bgr), tgt_bgr = input_batch
-            tgt_bgr = tgt_bgr.to(device, non_blocking=True)
-        else:
-            src, bgr = input_batch
-            tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
-        src = src.to(device, non_blocking=True)
-        bgr = bgr.to(device, non_blocking=True)
-
-        if args.model_type == 'mattingbase':
-            pha, fgr, err, _ = model(src, bgr)
-        elif args.model_type == 'mattingrefine':
-            pha, fgr, _, _, err, ref = model(src, bgr)
-        elif args.model_type == 'mattingbm':
-            pha, fgr = model(src, bgr)
-
-        if 'com' in args.output_types:
-            if args.output_format == 'video':
-                # Output composite with green background
-                com = fgr * pha + tgt_bgr * (1 - pha)
-                com_writer.add_batch(com)
-            else:
-                # Output composite as rgba png images
-                com = torch.cat([fgr * pha.ne(0), pha], dim=1)
-                com_writer.add_batch(com)
-        if 'pha' in args.output_types:
-            pha_writer.add_batch(pha)
-        if 'fgr' in args.output_types:
-            fgr_writer.add_batch(fgr)
-        if 'err' in args.output_types:
-            err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
-        if 'ref' in args.output_types:
-            ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
+if __name__ == '__main__':
+
+    device = torch.device(args.device)
+
+    # Load model
+    if args.model_type == 'mattingbase':
+        model = MattingBase(args.model_backbone)
+    if args.model_type == 'mattingrefine':
+        model = MattingRefine(
+            args.model_backbone,
+            args.model_backbone_scale,
+            args.model_refine_mode,
+            args.model_refine_sample_pixels,
+            args.model_refine_threshold,
+            args.model_refine_kernel_size)
+
+    model = model.to(device).eval()
+    model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
+
+    # Load video and background
+    vid = VideoDataset(args.video_src)
+    bgr = Image.open(args.video_bgr).convert('RGB')
+
+    transforms = T.Compose([
+        T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity(),
+        T.ToTensor()
+    ])
+
+    bgr = transforms(bgr)
+    dataset = VideoDataset(args.video_src, transforms=transforms)
+
+    if args.video_target_bgr:
+        dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
+
+    # Create output directory
+    if os.path.exists(args.output_dir):
+        if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
+            shutil.rmtree(args.output_dir)
+        else:
+            exit()
+    os.makedirs(args.output_dir)
+
+    # Prepare writers
+    if args.output_format == 'video':
+        h = args.video_resize[1] if args.video_resize is not None else vid.height
+        w = args.video_resize[0] if args.video_resize is not None else vid.width
+        if 'com' in args.output_types:
+            com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
+        if 'pha' in args.output_types:
+            pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
+        if 'fgr' in args.output_types:
+            fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
+        if 'err' in args.output_types:
+            err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
+        if 'ref' in args.output_types:
+            ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
+    else:
+        if 'com' in args.output_types:
+            com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
+        if 'pha' in args.output_types:
+            pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
+        if 'fgr' in args.output_types:
+            fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
+        if 'err' in args.output_types:
+            err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
+        if 'ref' in args.output_types:
+            ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
+
+    # Conversion loop
+    with torch.no_grad():
+        queue = Queue(1)
+
+        def load_worker():
+            # Default target background is a solid green, unless a target background video is provided
+            tgt_bgr = torch.tensor([120/255, 255/255, 155/255]).view(1, 3, 1, 1)
+            for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
+                if args.video_target_bgr:
+                    src, tgt_bgr = input_batch
+                else:
+                    src = input_batch
+                queue.put((src, tgt_bgr))
+            queue.put(None)
+
+        loader = Thread(target=load_worker)
+        loader.start()
+
+        # Move the background to the device once, outside the loop
+        bgr = bgr[None].to(device, non_blocking=False)
+        while True:
+            task = queue.get()
+            if task is None:
+                break
+            src, tgt_bgr = task
+            # Move the current frame batch to the device
+            src = src.to(device, non_blocking=True)
+            tgt_bgr = tgt_bgr.to(device, non_blocking=True)
+
+            if args.model_type == 'mattingbase':
+                pha, fgr, err, _ = model(src, bgr)
+            elif args.model_type == 'mattingrefine':
+                pha, fgr, _, _, err, ref = model(src, bgr)
+            elif args.model_type == 'mattingbm':
+                pha, fgr = model(src, bgr)
+
+            if 'com' in args.output_types:
+                if args.output_format == 'video':
+                    # Output composite with the target (default green) background
+                    com = fgr * pha + tgt_bgr * (1 - pha)
+                    com_writer.add_batch(com)
+                else:
+                    # Output composite as rgba png images
+                    com = torch.cat([fgr * pha.ne(0), pha], dim=1)
+                    com_writer.add_batch(com)
+            if 'pha' in args.output_types:
+                pha_writer.add_batch(pha)
+            if 'fgr' in args.output_types:
+                fgr_writer.add_batch(fgr)
+            if 'err' in args.output_types:
+                err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
+            if 'ref' in args.output_types:
+                ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
+
+        # Wait for the loader thread, then shut down the writer worker processes
+        loader.join()
+        if args.output_format == 'video':
+            if 'com' in args.output_types:
+                com_writer.close()
+            if 'pha' in args.output_types:
+                pha_writer.close()
+            if 'fgr' in args.output_types:
+                fgr_writer.close()
+            if 'err' in args.output_types:
+                err_writer.close()
+            if 'ref' in args.output_types:
+                ref_writer.close()
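
The rewritten loop above is a bounded producer-consumer: the loader thread decodes and preprocesses the next batch on the CPU while the GPU runs inference on the current one, and Queue(1) caps prefetching at a single batch so host memory stays flat. A stripped-down sketch of that structure, with print calls standing in for the model forward pass and the writers (illustrative only, not part of the PR):

from queue import Queue
from threading import Thread

queue = Queue(maxsize=1)     # at most one prefetched batch

def load_worker():
    for batch in range(5):   # stand-in for the DataLoader loop
        queue.put(batch)     # blocks until the consumer takes the previous batch
    queue.put(None)          # sentinel: no more batches

loader = Thread(target=load_worker)
loader.start()

while True:
    task = queue.get()
    if task is None:
        break
    print('running inference on batch', task)  # stand-in for model(src, bgr)

loader.join()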