-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_video_prepare.py
105 lines (81 loc) · 3.78 KB
/
main_video_prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import cv2
import numpy as np
import qoi
import torch
import torchvision
import torchvision.models.optical_flow
from torchvision import transforms
from tqdm import tqdm
class DivisibleBy8:
def __call__(self, x: torch.Tensor | np.ndarray) -> torch.Tensor | np.ndarray:
""" Preprocesses a tensor to be divisible by 8. This is required by the RAFT model. """
h, w = x.shape[:2]
h = h - h % 8
w = w - w % 8
x = x[:h, :w, ...]
return x
def frame_to_content(frame: np.ndarray) -> np.ndarray:
    """Crop a video frame to a RAFT-compatible size and reorder its channels to RGB.

    Args:
        frame: (H, W, C) array in BGR or BGRA channel order.

    Returns:
        (H', W', 3) array in RGB order, where H' and W' are H and W rounded
        down to multiples of 8. A fourth (alpha) channel, if present, is dropped.
    """
    cropped = DivisibleBy8()(frame)
    # Selecting channels 2, 1, 0 reverses BGR to RGB and discards any 4th channel.
    return cropped[..., [2, 1, 0]]
def frame_to_tensor(frame: np.ndarray, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Convert a video frame into a normalized, channels-first tensor for RAFT.

    Args:
        frame: (H, W, C) array in BGR or BGRA channel order.
        mean: Per-channel mean handed to ``transforms.Normalize``.
        std: Per-channel standard deviation handed to ``transforms.Normalize``.

    Returns:
        (C, H', W') tensor in RGB order, with H' and W' rounded down to
        multiples of 8, converted by ``transforms.ToTensor`` and then
        normalized as ``(x - mean) / std``.
    """
    # Crop to multiples of 8 and reverse BGR -> RGB (dropping alpha if present).
    rgb = DivisibleBy8()(frame)[..., [2, 1, 0]]
    # HWC ndarray -> CHW tensor.
    as_tensor = transforms.ToTensor()(rgb)
    normalize = transforms.Normalize(mean, std)
    return normalize(as_tensor)
def _predict_flow(model: torch.nn.Module, src: torch.Tensor, dst: torch.Tensor) -> np.ndarray:
    """Run RAFT on a single (src, dst) frame pair and return its final flow estimate.

    ``src`` and ``dst`` are unbatched (C, H, W) tensors on the model's device.
    The result is the first (only) batch item of the last prediction, moved to CPU.
    """
    # Inference only: avoid building an autograd graph over every RAFT iteration.
    with torch.inference_mode():
        predictions: list[torch.Tensor] = model(src[None], dst[None])
    # Per the torchvision RAFT docs, the model returns one prediction per
    # refinement iteration and the LAST one is the most refined.
    return predictions[-1][0].cpu().numpy()


def main(
    video_filepath: str = 'examples/video/sintel.mp4',
    root_path: str = "root/sintel.mp4",
    *,
    compute_backward_flow: bool = False,
) -> None:
    """Split a video into RGB frames and precompute RAFT optical flow between consecutive frames.

    Files written under ``root_path``:
      - ``frame/{i}/content.qoi``: frame ``i`` cropped to multiples of 8, in RGB.
      - ``flow/flow_{i-1}_to_{i}.npz``: flow from frame ``i-1`` to frame ``i``.
      - ``flow/flow_{i}_to_{i-1}.npz``: flow from frame ``i`` to frame ``i-1``
        (only when ``compute_backward_flow`` is true).
    Flow files that already exist are not recomputed. Each ``.npz`` holds a
    single array saved positionally (numpy's default key ``arr_0``).

    Args:
        video_filepath: Path of the input video.
        root_path: Output directory root.
        compute_backward_flow: Also compute and save backward flow.

    Raises:
        RuntimeError: If the video cannot be opened or a frame cannot be read.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torchvision.models.optical_flow.raft_large(
        weights=torchvision.models.optical_flow.Raft_Large_Weights.C_T_V2
    ).to(device)
    # Lighter alternative: torchvision.models.optical_flow.raft_small(
    #     weights=torchvision.models.optical_flow.Raft_Small_Weights.C_T_V2)
    model.eval()
    # mean = std = 0.5 maps ToTensor's [0, 1] range onto [-1, 1].
    raft_mean = torch.tensor((0.5, 0.5, 0.5)).to(device)
    raft_std = torch.tensor((0.5, 0.5, 0.5)).to(device)

    flow_dir = f"{root_path}/flow"
    os.makedirs(flow_dir, exist_ok=True)

    vidcap = cv2.VideoCapture(video_filepath)
    try:
        if not vidcap.isOpened():
            # Without this check a missing/unreadable file silently yields 0 frames.
            raise RuntimeError(f"Failed to open video {video_filepath!r}")
        # NOTE(review): CAP_PROP_FRAME_COUNT comes from container metadata and
        # may be an estimate for some formats — confirm for the inputs used.
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

        content_prev = None
        for frame_idx in tqdm(range(frame_count), desc="Splitting video and computing optical flow"):
            frame_dir = f"{root_path}/frame/{frame_idx}"
            os.makedirs(frame_dir, exist_ok=True)

            success, frame_cur = vidcap.read()
            if not success:
                raise RuntimeError(f"Failed to read frame {frame_idx} from video of {frame_count} frames")

            # NOTE(review): .copy() presumably ensures a C-contiguous buffer for qoi.write — confirm.
            content = frame_to_content(frame_cur).copy()
            qoi.write(f"{frame_dir}/content.qoi", content)

            content_cur = frame_to_tensor(frame_cur, raft_mean, raft_std).to(device)
            if content_prev is not None:
                forward_flow_filepath = f"{flow_dir}/flow_{frame_idx-1}_to_{frame_idx}.npz"
                if not os.path.exists(forward_flow_filepath):
                    np.savez_compressed(forward_flow_filepath, _predict_flow(model, content_prev, content_cur))
                if compute_backward_flow:
                    backward_flow_filepath = f"{flow_dir}/flow_{frame_idx}_to_{frame_idx-1}.npz"
                    if not os.path.exists(backward_flow_filepath):
                        np.savez_compressed(backward_flow_filepath, _predict_flow(model, content_cur, content_prev))
            content_prev = content_cur
    finally:
        # Release the decoder handle even if reading or inference fails.
        vidcap.release()
if __name__ == "__main__":
    # Script entry point; importing this module does not start processing.
    main()