forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
118 lines (102 loc) · 4.14 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utility functions for reading the datasets."""
import functools
import json
import os
import tensorflow.compat.v1 as tf
from meshgraphnets.common import NodeType
def _parse(proto, meta):
  """Decodes one serialized trajectory record into a dict of tensors.

  Args:
    proto: serialized example handed to `tf.io.parse_single_example`.
    meta: dataset metadata dict. Reads `field_names` (names of the stored
      byte-string features), `features` (per-field `dtype`, `shape` and
      `type`) and `trajectory_length`.

  Returns:
    Dict mapping every key of `meta['features']` to its decoded tensor;
    `dynamic_varlen` fields are returned as `tf.RaggedTensor`.

  Raises:
    ValueError: if a field's `type` is not `static`, `dynamic` or
      `dynamic_varlen`.
  """
  # Every stored field is a variable-length byte string.
  schema = {}
  for name in meta['field_names']:
    schema[name] = tf.io.VarLenFeature(tf.string)
  parsed = tf.io.parse_single_example(proto, schema)

  trajectory = {}
  for name, spec in meta['features'].items():
    raw_bytes = parsed[name].values
    tensor = tf.io.decode_raw(raw_bytes, getattr(tf, spec['dtype']))
    tensor = tf.reshape(tensor, spec['shape'])
    kind = spec['type']
    if kind == 'static':
      # Repeat the field `trajectory_length` times along the leading axis.
      tensor = tf.tile(tensor, [meta['trajectory_length'], 1, 1])
    elif kind == 'dynamic_varlen':
      # Row lengths are stored alongside the data under 'length_<name>'.
      row_lengths = tf.io.decode_raw(
          parsed['length_'+name].values, tf.int32)
      row_lengths = tf.reshape(row_lengths, [-1])
      tensor = tf.RaggedTensor.from_row_lengths(
          tensor, row_lengths=row_lengths)
    elif kind != 'dynamic':
      raise ValueError('invalid data format')
    trajectory[name] = tensor
  return trajectory
def load_dataset(path, split):
  """Builds a dataset of parsed trajectories for one split.

  Args:
    path: directory containing `meta.json` and the `<split>.tfrecord` file.
    split: split name; the records are read from `<split>.tfrecord`.

  Returns:
    A dataset whose elements are the dicts produced by `_parse`, mapped with
    8 parallel calls and prefetched one element ahead.
  """
  # JSON is UTF-8 encoded; decode it explicitly rather than relying on the
  # platform's locale encoding, which may differ (e.g. on Windows).
  with open(os.path.join(path, 'meta.json'), 'r', encoding='utf-8') as fp:
    meta = json.load(fp)
  ds = tf.data.TFRecordDataset(os.path.join(path, split+'.tfrecord'))
  ds = ds.map(functools.partial(_parse, meta=meta), num_parallel_calls=8)
  ds = ds.prefetch(1)
  return ds
def add_targets(ds, fields, add_history):
  """Adds target (and optionally history) entries to every trajectory.

  Each trajectory value is trimmed to its interior steps `[1:-1]`. For every
  key listed in `fields`, a `'target|<key>'` entry holding the following step
  (`[2:]`) is added, and, when `add_history` is true, a `'prev|<key>'` entry
  holding the preceding step (`[0:-2]`).

  Args:
    ds: dataset whose elements are dicts of sliceable per-step values.
    fields: container of keys that should receive target/history entries.
    add_history: whether to also emit the `'prev|<key>'` entries.

  Returns:
    The result of `ds.map` applied with the transformation above.
  """

  def with_targets(trajectory):
    shifted = {}
    for name, steps in trajectory.items():
      # Drop the first and last step so each kept step has both neighbours.
      shifted[name] = steps[1:-1]
      if name not in fields:
        continue
      if add_history:
        shifted['prev|'+name] = steps[:-2]
      shifted['target|'+name] = steps[2:]
    return shifted

  return ds.map(with_targets, num_parallel_calls=8)
def split_and_preprocess(ds, noise_field, noise_scale, noise_gamma):
  """Flattens trajectories into single frames and injects training noise.

  Args:
    ds: dataset of trajectories; each element is sliced along its first axis
      into individual frames.
    noise_field: key of the field to perturb. The frame must also contain
      `'target|' + noise_field` and `'node_type'`.
    noise_scale: standard deviation of the normal noise.
    noise_gamma: the target receives `(1.0 - noise_gamma)` times the noise
      that is added to the input field.

  Returns:
    A dataset of noisy frames, shuffled with a buffer of 10000, repeated
    indefinitely and prefetched 10 elements ahead.
  """

  def perturb(frame):
    target_key = 'target|'+noise_field
    sampled = tf.random.normal(
        tf.shape(frame[noise_field]), stddev=noise_scale, dtype=tf.float32)
    # Only NORMAL nodes are perturbed; all other node types get zero noise.
    is_normal = tf.equal(frame['node_type'], NodeType.NORMAL)[:, 0]
    sampled = tf.where(is_normal, sampled, tf.zeros_like(sampled))
    frame[noise_field] += sampled
    frame[target_key] += (1.0 - noise_gamma) * sampled
    return frame

  frames = ds.flat_map(tf.data.Dataset.from_tensor_slices)
  frames = frames.map(perturb, num_parallel_calls=8)
  return frames.shuffle(10000).repeat().prefetch(10)
def batch_dataset(ds, batch_size):
  """Merges `batch_size` consecutive frames into one concatenated frame.

  Every field is concatenated along its first axis. The `'cells'` field is
  additionally renumbered: the indices of each frame are offset by the
  number of nodes (first dimension of `'node_type'`) accumulated from the
  frames merged before it.

  Args:
    ds: dataset of frames (dicts of tensors). Its `output_shapes` and
      `output_types` are read to build the empty accumulators.
    batch_size: number of frames per batch. Values of 1 or less return `ds`
      unchanged; incomplete trailing windows are dropped.

  Returns:
    The batched dataset, or `ds` itself when `batch_size <= 1`.
  """
  shapes = ds.output_shapes
  types = ds.output_types

  def concat_rows(accumulated, current):
    return tf.concat([accumulated, current], axis=0)

  def concat_cells(state, item):
    node_offset, merged_cells = state
    node_count, frame_cells = item
    # Shift this frame's indices past all nodes merged so far.
    shifted = frame_cells+node_offset
    return node_offset + node_count, tf.concat([merged_cells, shifted], axis=0)

  def merge_window(window):
    merged = {}
    for key, frames in window.items():
      # Empty accumulator with zero rows and the field's second dimension.
      empty = tf.zeros((0, shapes[key][1]), dtype=types[key])
      if key != 'cells':
        merged[key] = frames.reduce(empty, concat_rows)
        continue
      node_counts = window['node_type'].map(lambda x: tf.shape(x)[0])
      counted_cells = tf.data.Dataset.zip((node_counts, frames))
      start = (tf.constant(0, tf.int32), empty)
      _, merged[key] = counted_cells.reduce(start, concat_cells)
    return merged

  if batch_size > 1:
    ds = ds.window(batch_size, drop_remainder=True)
    ds = ds.map(merge_window, num_parallel_calls=8)
  return ds