# dataio.py

import os
import torch
import numpy as np
from glob import glob
import data_util
import util
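
# data_util and util are companion modules from the same repository; they
# provide the image globbing/loading, pose/param loading, and intrinsics
# parsing helpers used below.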


def pick(lst, item_idcs):
    """Returns the elements of lst selected by item_idcs."""
    if not lst:
        return lst
    return [lst[i] for i in item_idcs]
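
# Example (illustrative): pick(["a", "b", "c"], [2, 0]) == ["c", "a"].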


class SceneInstanceDataset():
    """This creates a dataset class for a single object instance (such as a single car)."""

    def __init__(self,
                 instance_idx,
                 instance_dir,
                 specific_observation_idcs=None,  # For few-shot case: Can pick specific observations only
                 img_sidelength=None,
                 num_images=-1):
        self.instance_idx = instance_idx
        self.img_sidelength = img_sidelength
        self.instance_dir = instance_dir

        color_dir = os.path.join(instance_dir, "rgb")
        pose_dir = os.path.join(instance_dir, "pose")
        param_dir = os.path.join(instance_dir, "params")

        if not os.path.isdir(color_dir):
            # Printing and returning would leave the object half-constructed;
            # fail loudly instead.
            raise ValueError("Error! root dir %s is wrong" % instance_dir)

        self.has_params = os.path.isdir(param_dir)
        self.color_paths = sorted(data_util.glob_imgs(color_dir))
        self.pose_paths = sorted(glob(os.path.join(pose_dir, "*.txt")))

        if self.has_params:
            self.param_paths = sorted(glob(os.path.join(param_dir, "*.txt")))
        else:
            self.param_paths = []

        if specific_observation_idcs is not None:
            self.color_paths = pick(self.color_paths, specific_observation_idcs)
            self.pose_paths = pick(self.pose_paths, specific_observation_idcs)
            self.param_paths = pick(self.param_paths, specific_observation_idcs)
        elif num_images != -1:
            # Subsample num_images observations, evenly spaced across the instance.
            idcs = np.linspace(0, stop=len(self.color_paths), num=num_images, endpoint=False, dtype=int)
            self.color_paths = pick(self.color_paths, idcs)
            self.pose_paths = pick(self.pose_paths, idcs)
            self.param_paths = pick(self.param_paths, idcs)

    def set_img_sidelength(self, new_img_sidelength):
        """For multi-resolution training: Updates the image sidelength at which images are loaded."""
        self.img_sidelength = new_img_sidelength

    def __len__(self):
        return len(self.pose_paths)

    def __getitem__(self, idx):
        intrinsics, _, _, _ = util.parse_intrinsics(os.path.join(self.instance_dir, "intrinsics.txt"),
                                                    trgt_sidelength=self.img_sidelength)
        intrinsics = torch.Tensor(intrinsics).float()

        rgb = data_util.load_rgb(self.color_paths[idx], sidelength=self.img_sidelength)
        rgb = rgb.reshape(3, -1).transpose(1, 0)  # (num_pixels, 3)

        pose = data_util.load_pose(self.pose_paths[idx])

        if self.has_params:
            params = data_util.load_params(self.param_paths[idx])
        else:
            params = np.array([0])

        # Per-pixel coordinates; np.mgrid yields (row, col), so flip to (x, y).
        uv = np.mgrid[0:self.img_sidelength, 0:self.img_sidelength].astype(np.int32)
        uv = torch.from_numpy(np.flip(uv, axis=0).copy()).long()
        uv = uv.reshape(2, -1).transpose(1, 0)  # (num_pixels, 2)

        sample = {
            "instance_idx": torch.Tensor([self.instance_idx]).squeeze(),
            "rgb": torch.from_numpy(rgb).float(),
            "pose": torch.from_numpy(pose).float(),
            "uv": uv,
            "param": torch.from_numpy(params).float(),
            "intrinsics": intrinsics,
        }
        return sample
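
# Illustrative usage (a sketch, not part of the training pipeline; the
# directory path below is hypothetical but must contain the rgb/, pose/ and
# intrinsics.txt entries parsed above):
#
#     instance = SceneInstanceDataset(instance_idx=0,
#                                     instance_dir="data/cars/instance_0000",
#                                     img_sidelength=64)
#     sample = instance[0]
#     # sample["rgb"]: (64 * 64, 3) float tensor; sample["uv"]: (64 * 64, 2) long tensor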


class SceneClassDataset(torch.utils.data.Dataset):
    """Dataset for a class of objects, where each datapoint is a SceneInstanceDataset."""

    def __init__(self,
                 root_dir,
                 img_sidelength=None,
                 max_num_instances=-1,
                 max_observations_per_instance=-1,
                 specific_observation_idcs=None,  # For few-shot case: Can pick specific observations only
                 samples_per_instance=2):
        self.samples_per_instance = samples_per_instance
        self.instance_dirs = sorted(glob(os.path.join(root_dir, "*/")))

        assert len(self.instance_dirs) != 0, "No objects in the data directory"

        if max_num_instances != -1:
            self.instance_dirs = self.instance_dirs[:max_num_instances]

        self.all_instances = [SceneInstanceDataset(instance_idx=idx,
                                                   instance_dir=dir,
                                                   specific_observation_idcs=specific_observation_idcs,
                                                   img_sidelength=img_sidelength,
                                                   num_images=max_observations_per_instance)
                              for idx, dir in enumerate(self.instance_dirs)]

        self.num_per_instance_observations = [len(obj) for obj in self.all_instances]
        self.num_instances = len(self.all_instances)

    def set_img_sidelength(self, new_img_sidelength):
        """For multi-resolution training: Updates the image sidelength at which images are loaded."""
        for instance in self.all_instances:
            instance.set_img_sidelength(new_img_sidelength)

    def __len__(self):
        return np.sum(self.num_per_instance_observations)

    def get_instance_idx(self, idx):
        """Maps a flat index over all observations of all objects to a tuple of
        (object index, index of the observation within that object)."""
        obj_idx = 0
        while idx >= 0:
            idx -= self.num_per_instance_observations[obj_idx]
            obj_idx += 1
        return obj_idx - 1, int(idx + self.num_per_instance_observations[obj_idx - 1])
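
    # Example: with num_per_instance_observations == [3, 5], a flat idx of 4
    # maps to (obj_idx, rel_idx) == (1, 1): the loop subtracts 3, then 5,
    # stopping at idx == -4, and -4 + 5 == 1 is the observation index within
    # object 1.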

    def collate_fn(self, batch_list):
        batch_list = zip(*batch_list)

        all_parsed = []
        for entry in batch_list:
            # Make them all into a new dict
            ret = {}
            for k in entry[0][0].keys():
                ret[k] = []

            # Flatten the list of lists
            for b in entry:
                for k in entry[0][0].keys():
                    ret[k].extend([bi[k] for bi in b])

            for k in ret.keys():
                if isinstance(ret[k][0], torch.Tensor):
                    ret[k] = torch.stack(ret[k])
            all_parsed.append(ret)

        return tuple(all_parsed)
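
    # Each batch element is an (observations, ground_truth) tuple of lists of
    # dicts, so zip(*batch_list) regroups the batch into one entry per tuple
    # position. After stacking, tensors such as ret["rgb"] gain a leading
    # dimension of batch_size * samples_per_instance.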

    def __getitem__(self, idx):
        """Each __getitem__ call yields a list of self.samples_per_instance observations of a single scene (each a dict),
        as well as a list of ground-truths for each observation (also a dict)."""
        obj_idx, rel_idx = self.get_instance_idx(idx)

        # The first observation is the one addressed by idx; the rest are drawn
        # uniformly at random from the same instance.
        observations = []
        observations.append(self.all_instances[obj_idx][rel_idx])

        for _ in range(self.samples_per_instance - 1):
            observations.append(self.all_instances[obj_idx][np.random.randint(len(self.all_instances[obj_idx]))])

        ground_truth = [{'rgb': ray_bundle['rgb']} for ray_bundle in observations]

        return observations, ground_truth
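

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original pipeline: the root
    # directory below is hypothetical and must follow the layout
    # root_dir/<instance>/{rgb, pose, params, intrinsics.txt} parsed above.
    from torch.utils.data import DataLoader

    dataset = SceneClassDataset(root_dir="data/cars_train",
                                img_sidelength=64,
                                samples_per_instance=2)

    # collate_fn regroups each batch of (observations, ground_truth) tuples
    # into a (model_input, ground_truth) pair of dicts of stacked tensors.
    dataloader = DataLoader(dataset,
                            batch_size=4,
                            shuffle=True,
                            collate_fn=dataset.collate_fn)

    model_input, ground_truth = next(iter(dataloader))
    print(model_input["rgb"].shape)  # (batch_size * samples_per_instance, 64 * 64, 3)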