# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
"""Utility functions for the training script."""
import collections
import math
from absl import flags
from absl import logging
import jax.numpy as jnp
import jraph
import numpy as np
import sklearn.metrics
from wikigraphs.data import paired_dataset as pd
from wikigraphs.data import tokenizers
from wikigraphs.data import wikitext as wt
from wikigraphs.model import graph_net as gn
from wikigraphs.model import sampler as transformer_sampler
from wikigraphs.model import transformer
FLAGS = flags.FLAGS
VOCAB_FILES_MAP = {
'wikitext': '/tmp/data/wikitext-vocab.csv',
'freebase2wikitext': '/tmp/data/text-vocab.csv',
}
GRAPH_VOCAB_FILE = '/tmp/data/graph-vocab.csv'
def init_tokenizer(dataset_name):
"""Initialie the tokenizer."""
logging.info('Loading tokenizer...')
tokenizer = tokenizers.WordTokenizer(VOCAB_FILES_MAP[dataset_name])
logging.info('Vocab size: %d', tokenizer.vocab_size)
return tokenizer
def init_graph_tokenizer():
"""Initialie the tokenizer."""
logging.info('Loading graph tokenizer...')
tokenizer = tokenizers.GraphTokenizer(GRAPH_VOCAB_FILE)
logging.info('Vocab size: %d', tokenizer.vocab_size)
return tokenizer
def get_dataset_class(dataset_name, model_type, job_mode='train'):
"""Get the dataset class used for all jobs."""
if dataset_name == 'freebase2wikitext':
if model_type == 'bow2text':
return pd.Bow2TextDataset
    elif model_type == 'graph2text':
      return pd.Graph2TextDataset
    elif model_type == 'text':
if job_mode in ['train', 'eval']:
return pd.TextOnlyDataset
else:
# for sampling: taking the unique graphs for a fair comparison
return pd.Bow2TextDataset
else:
# Add other graph2text data here.
raise NotImplementedError()
else:
def dataset(graph_tokenizer, *args, **kwargs):
del graph_tokenizer
return wt.Dataset(*args, **kwargs)
return dataset
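# Illustrative use of `get_dataset_class` (a sketch; the exact constructor
# signature is defined by the dataset classes in `wikigraphs.data` and is
# assumed here):
#
#   dataset_class = get_dataset_class('freebase2wikitext', FLAGS.model_type)
#   train_set = dataset_class(graph_tokenizer, tokenizer=tokenizer,
#                             subset='train')
#
# Note that the wikitext branch returns a thin wrapper that drops the
# `graph_tokenizer` argument, so both branches share one call signature.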
def preprocess(batch, model_type, num_devices=1):
"""Preprocess the batch before sending to the model."""
if model_type == 'text':
if 'graphs' in batch:
del batch['graphs']
elif model_type == 'bow2text':
# Do nothing, bow2text data is already in a good form.
pass
else: # graph2text
if num_devices == 1:
graphs = gn.pad_graphs(jraph.batch(batch['graphs']))
else:
      # We first need to batch graphs into num_devices batches.
graphs = gn.batch_graphs_by_device(batch['graphs'], num_devices)
# Then we pad them to the maximum graph size in the batch and concat.
# This way graphs can be distributed to each device through pmap.
graphs = gn.pad_graphs_by_device(graphs)
max_graph_size = gn.pad_size(graphs.n_node.max())
batch.update({
'graphs': graphs,
'max_graph_size': max_graph_size})
return batch
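# For example, with num_devices=2 and 8 graphs in a batch, the graphs are
# split into 2 per-device batches of 4, each padded to a common size, so the
# leading device axis can be consumed directly by `jax.pmap`.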
def text_model_fn(vocab_size):
return transformer.TransformerXL(
vocab_size=vocab_size,
emb_dim=FLAGS.emb_dim,
num_layers=FLAGS.num_layers,
num_heads=FLAGS.num_heads,
dropout_prob=FLAGS.dropout,
dropout_attn_prob=FLAGS.dropout_attn,
self_att_init_scale=FLAGS.self_att_init_scale,
dense_init_scale=FLAGS.dense_init_scale,
dense_dim=FLAGS.dense_dim,
tail_shrink_factor=FLAGS.tail_shrink_factor,
relative_pos_clamp_len=FLAGS.clamp_len or None)
def graph2text_model_fn(vocab_size):
"""Get graph2text transformer model."""
return transformer.Graph2TextTransformer(
vocab_size=vocab_size,
emb_dim=FLAGS.emb_dim,
num_layers=FLAGS.num_layers,
num_heads=FLAGS.num_heads,
dropout_prob=FLAGS.dropout,
dropout_attn_prob=FLAGS.dropout_attn,
self_att_init_scale=FLAGS.self_att_init_scale,
dense_init_scale=FLAGS.dense_init_scale,
dense_dim=FLAGS.dense_dim,
tail_shrink_factor=FLAGS.tail_shrink_factor,
relative_pos_clamp_len=FLAGS.clamp_len or None,
gnn_embed_dim=FLAGS.gnn_embed_dim,
gnn_num_layers=FLAGS.gnn_num_layers,
gnn_layer_norm=FLAGS.gnn_layer_norm)
def bow2text_model_fn(vocab_size):
"""Get the bow2text model."""
return transformer.Bow2TextTransformer(
vocab_size=vocab_size,
emb_dim=FLAGS.emb_dim,
num_layers=FLAGS.num_layers,
num_heads=FLAGS.num_heads,
dropout_prob=FLAGS.dropout,
dropout_attn_prob=FLAGS.dropout_attn,
self_att_init_scale=FLAGS.self_att_init_scale,
dense_init_scale=FLAGS.dense_init_scale,
dense_dim=FLAGS.dense_dim,
tail_shrink_factor=FLAGS.tail_shrink_factor,
relative_pos_clamp_len=FLAGS.clamp_len or None,
bow_embedding_dim=FLAGS.bow_embedding_dim,
bow_n_tokens=FLAGS.bow_n_tokens)
def build_loss_fn(vocab_size, cache_steps):
"""Build the appropriate loss function according to the configs."""
if FLAGS.model_type == 'text':
def loss_fn(data, is_training=True):
return text_model_fn(vocab_size=vocab_size).loss(
data['obs'], data['target'], data['mask'],
is_training=is_training,
should_reset=data['should_reset'],
cache_steps=cache_steps)
elif FLAGS.model_type == 'graph2text':
def loss_fn(data, max_graph_size, is_training=True):
return graph2text_model_fn(vocab_size=vocab_size).loss(
data['graphs'], max_graph_size, True,
data['obs'], data['target'], data['mask'],
is_training=is_training,
should_reset=data['should_reset'],
cache_steps=cache_steps)
elif FLAGS.model_type == 'bow2text':
def loss_fn(data, is_training=True):
return bow2text_model_fn(vocab_size=vocab_size).loss(
data['graphs'], data['obs'], data['target'], data['mask'],
is_training=is_training,
should_reset=data['should_reset'],
cache_steps=cache_steps)
else:
raise ValueError(f'Unknown model type "{FLAGS.model_type}".')
return loss_fn
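# A minimal sketch of how the returned `loss_fn` is typically consumed. It
# assumes the training script wraps it with Haiku (the model classes in
# `wikigraphs.model.transformer` are Haiku modules); that wrapping is not
# shown in this file:
#
#   import haiku as hk
#   loss_fn = build_loss_fn(vocab_size=tokenizer.vocab_size, cache_steps=64)
#   loss = hk.transform_with_state(loss_fn)
#   params, state = loss.init(rng, preprocess(batch, FLAGS.model_type))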
def build_sampler(tokenizer, device=None):
"""Build the appropriate sampler according to the configs."""
if FLAGS.model_type == 'text':
model_fn = lambda prompts: text_model_fn(tokenizer.vocab_size)( # pylint: disable=g-long-lambda
prompts, is_training=False, cache_steps=FLAGS.sample_memory_size)
sampler_class = transformer_sampler.TransformerXLSampler
elif FLAGS.model_type == 'graph2text':
def model_fn(graphs, max_graph_size, prompts):
return graph2text_model_fn(tokenizer.vocab_size)(
graphs, max_graph_size, True, prompts, is_training=False,
cache_steps=FLAGS.sample_memory_size)
sampler_class = transformer_sampler.Graph2TextTransformerSampler
elif FLAGS.model_type == 'bow2text':
def model_fn(graphs, prompts):
return bow2text_model_fn(tokenizer.vocab_size)(
graphs, prompts, is_training=False,
cache_steps=FLAGS.sample_memory_size)
    sampler_class = transformer_sampler.Bow2TextTransformerSampler
  else:
    raise ValueError(f'Unknown model type "{FLAGS.model_type}".')
  sampler = sampler_class(model_fn, FLAGS.sampling_temperature, device)
return sampler
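# Sampling sketch for the 'text' model type (`params` and `tokenizer` come
# from the surrounding training script and are assumed here):
#
#   sampler = build_sampler(tokenizer)
#   prompts = construct_prompts(None, batch_size=1, sample_length=128,
#                               tokenizer=tokenizer, prompt_title=False)
#   texts, tokens = generate_samples(params, tokenizer, sampler,
#                                    'text', prompts, graphs=None)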
def schedule(i, lr_schedule, init_lr, min_lr_ratio, max_steps):
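  """Learning rate schedule.
  The 'cosine' option decays to min_lr_ratio * init_lr over max_steps; any
  other value uses a staircase schedule that divides init_lr by 3 at steps
  150k, 250k and 350k.
  """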
if lr_schedule == 'cosine':
cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * i / max_steps))
decayed = (1 - min_lr_ratio) * cosine_decay + min_lr_ratio
return init_lr * decayed
else:
return jnp.where(
i > 350000, init_lr / 3**3,
jnp.where(i > 250000, init_lr / 3**2,
jnp.where(i > 150000, init_lr / 3, init_lr)))
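# Worked example for the staircase branch: with init_lr=2.5e-4 (an
# illustrative value), the learning rate stays at 2.5e-4 up to step 150k and
# is divided by 3 at steps 150k, 250k and 350k, giving roughly 8.3e-5,
# 2.8e-5 and 9.3e-6.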
def evaluate(eval_set, initial_state, updater, eval_batch_size=1,
preprocess_fn=None, max_eval_samples=-1,
print_progress_every=None):
"""Evaluate a model on given dataset."""
total_losses = []
total_counts = []
token_accuracy = []
seq_accuracy = []
state = initial_state
step = state['step']
for i, batch in enumerate(eval_set):
state, eval_out = updater.eval_return_state(state, preprocess_fn(batch))
total_losses.append(eval_out['total_loss'])
total_counts.append(eval_out['total_count'])
token_accuracy.append(
eval_out['token_accuracy'] * eval_out['total_count'])
seq_accuracy.append(eval_out['seq_accuracy'])
if print_progress_every and (i + 1) % print_progress_every == 0:
total_loss = float(jnp.array(total_losses).sum())
total_count = float(jnp.array(total_counts).sum())
avg_loss = total_loss / total_count
bpc = avg_loss * np.log2(np.e)
perplexity = np.exp(avg_loss)
logging.info(
'Evaluated %d batches, total tokens %d, average loss %g,'
' bpc %g, perplexity %g.',
i + 1, total_count, avg_loss, bpc, perplexity)
if 0 < max_eval_samples <= (i + 1) * eval_batch_size:
break
total_loss = jnp.array(total_losses).sum()
total_count = jnp.array(total_counts).sum()
avg_loss = total_loss / total_count
eval_out = dict(total_loss=float(total_loss),
total_count=float(total_count),
loss=float(avg_loss),
token_accuracy=float(
jnp.array(token_accuracy).sum() / total_count),
seq_accuracy=float(
jnp.array(seq_accuracy).sum() / len(seq_accuracy)),
step=float(step),
bits_per_token=float(avg_loss) * np.log2(np.e),
perplexity=np.exp(float(avg_loss)))
return eval_out, state
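# The reported metrics are simple transforms of the average per-token loss in
# nats: bits_per_token = loss * log2(e) and perplexity = exp(loss). For
# instance, a loss of 1.0 nat/token gives ~1.4427 bits/token and a perplexity
# of ~2.718.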
def extract_title(text, tokenizer):
r"""Extract the title in the text.
The wikitext articles is in the format of `\n = TITLE = \n \n...`. We extract
the title as the tokens from the start to when the `\n \n` first appears.
Args:
text: tokenized input text using `tokenizer`.
tokenizer: text tokenizer.
Returns:
title_end_idx: a numpy.array of shape (batch_size,), it indicates the index
in `text` that marks the end of the title.
"""
batch_size, text_length = text.shape
title_end_idx = np.ones(batch_size, dtype=np.int32)
newline_token = tokenizer.encode('\n')[0]
for b in range(batch_size):
prev_token = 1 # start tokens
for i in range(1, text_length): # skip start token
# when we first see '\n \n', that is the title
if prev_token == newline_token and text[b, i] == newline_token:
title_end_idx[b] = i
break
else:
prev_token = text[b, i]
return title_end_idx
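# Illustrative trace (token ids are made up): for a row tokenized from
# '<bos> = Some Title = \n \n First paragraph ...', the scan stops at the
# first position i where text[b, i-1] and text[b, i] are both the '\n' token,
# so text[b, :title_end_idx[b] + 1] spans '<bos> = Some Title = \n \n'.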
def construct_prompts(text, batch_size, sample_length, tokenizer, prompt_title):
"""Construct prompts for text generation.
Args:
text: tokenized input text using `tokenizer`.
batch_size: the size of the batch.
sample_length: the length of the sample to be generated.
tokenizer: text tokenizer.
prompt_title: whether to return a prompt with the title of the `text`.
Returns:
prompts: a numpy.array of shape [batch_size, sample_length], in which -1
indicates tokens that need to be generated using the sampler.
"""
prompts = -np.ones((batch_size, sample_length), dtype=np.int32)
prompts[:, 0] = tokenizer.bos_token()
if prompt_title and text is not None:
title_end_idx = extract_title(text, tokenizer)
for i in range(batch_size):
prompts[i, 1:title_end_idx[i]+1] = text[i, 1:title_end_idx[i]+1]
return prompts
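# Illustrative prompt layout with sample_length=8 and a title occupying
# positions 1-3:
#
#   prompts[i] == [bos, t1, t2, t3, -1, -1, -1, -1]
#
# where t1..t3 are the copied title tokens and the -1 entries mark the
# positions the sampler will fill in.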
def generate_samples(params, tokenizer, sampler, model_type, prompts, graphs):
"""Generate a batch of samples using a sampler."""
if model_type == 'text':
samples = sampler.sample(params, prompts)
elif model_type == 'graph2text':
samples = sampler.sample(params, prompts, graphs, pad=True)
elif model_type == 'bow2text':
samples = sampler.sample(params, prompts, graphs)
else:
raise ValueError(f'Unknown model_type {model_type}')
return [tokenizer.decode(s) for s in samples], samples
def take_unique_graphs(data_iter, model_type):
"""Filter data such that it only returns batches with unique graphs."""
prev_graphs = None
for batch in data_iter:
graphs = batch.get('graphs', None)
# If there's no graph in batch, don't do any filtering
if graphs is None:
yield batch
else:
if prev_graphs is None:
prev_graphs = graphs
yield batch
else:
if model_type == 'graph2text':
not_same_graph = (prev_graphs.nodes.shape != graphs.nodes.shape or
not (prev_graphs.nodes == graphs.nodes).all())
else:
not_same_graph = (prev_graphs.shape != graphs.shape or
not (prev_graphs == graphs).all())
if not_same_graph:
prev_graphs = graphs
yield batch
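# Intended use (assumes batches shaped like those from the paired datasets
# above): consecutive batches carrying the same graph are skipped, so each
# unique graph is consumed once, e.g.
#
#   for batch in take_unique_graphs(data_iter, FLAGS.model_type):
#     ...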
def compute_map_sklearn(pred, gt):
"""Computes mAP using scikit-learn."""
assert len(gt.shape) == len(pred.shape) == 2, (
'gt should be a one-hot encoding with the same shape as pred')
ap = [
sklearn.metrics.average_precision_score(
gt[c, :], pred[c, :], average=None)
for c in range(gt.shape[0])
]
return sum(ap) / len(ap)
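# Sanity-check sketch: each row of `pred` scores one query against all
# candidates and `gt` is the matching one-hot matrix, so perfect predictions
# score 1.0:
#
#   compute_map_sklearn(np.eye(3), np.eye(3))  # == 1.0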
def compute_recall_at_k(pred, k=1):
"""Computes recall@1 score."""
num_articles = pred.shape[1]
return sklearn.metrics.top_k_accuracy_score(
np.arange(num_articles), pred, k=k)
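# For example, a diagonal score matrix ranks every article's own match first:
#
#   compute_recall_at_k(np.eye(4), k=1)  # == 1.0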
def compute_text_graph_relevance(
eval_set, initial_state, updater, eval_batch_size=1, preprocess_fn=None,
print_progress_every=None):
"""Compute the text and graph relevance a model on given dataset."""
assert eval_batch_size == 1
num_articles = eval_set.num_articles
tokens_count = np.zeros((num_articles, num_articles))
log_probs = np.zeros((num_articles, num_articles)) # [graphs, texts]
state = initial_state
for i, batch in enumerate(eval_set):
state, eval_out = updater.eval_return_state(state, preprocess_fn(batch))
graph_id = batch['graph_id'][0]
seq_id = batch['seq_id'][0]
tokens_count[graph_id, seq_id] += eval_out['total_count']
log_probs[graph_id, seq_id] += eval_out['log_probs']
if print_progress_every is not None and (i + 1) % print_progress_every == 0:
logging.info('Evaluated %d samples', i + 1)
log_probs_per_token = log_probs / tokens_count
labels = np.eye(num_articles)
eval_out = dict(
log_probs=log_probs,
tokens_count=tokens_count,
log_probs_per_token=log_probs_per_token,
text2graph_recall_at_1=compute_recall_at_k(log_probs_per_token.T, k=1),
text2graph_recall_at_5=compute_recall_at_k(log_probs_per_token.T, k=5),
text2graph_map=compute_map_sklearn(log_probs_per_token.T, labels),
graph2text_recall_at_1=compute_recall_at_k(log_probs_per_token, k=1),
graph2text_recall_at_5=compute_recall_at_k(log_probs_per_token, k=5),
graph2text_map=compute_map_sklearn(log_probs_per_token, labels))
return eval_out, state
def _get_ngrams(segment, max_order):
"""Extracts all n-grams upto a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
Returns:
The Counter containing all n-grams upto max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i:i+order])
ngram_counts[ngram] += 1
return ngram_counts
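# For example:
#
#   _get_ngrams(['a', 'b', 'a'], max_order=2)
#   # -> Counter({('a',): 2, ('b',): 1, ('a', 'b'): 1, ('b', 'a'): 1})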
def compute_bleu(reference_corpus, translation_corpus, max_order=4,
smooth=False):
"""Computes BLEU score of translated segments against one or more references.
Originally from tensor2tensor/tensor2tensor/utils/bleu_hook.py
Args:
reference_corpus: list of lists of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
BLEU score and n-gram precisions.
"""
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
reference_length = 0
translation_length = 0
for (references, translation) in zip(reference_corpus,
translation_corpus):
reference_length += min(len(r) for r in references)
translation_length += len(translation)
merged_ref_ngram_counts = collections.Counter()
for reference in references:
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
translation_ngram_counts = _get_ngrams(translation, max_order)
overlap = translation_ngram_counts & merged_ref_ngram_counts
for ngram in overlap:
matches_by_order[len(ngram)-1] += overlap[ngram]
for order in range(1, max_order+1):
possible_matches = len(translation) - order + 1
if possible_matches > 0:
possible_matches_by_order[order-1] += possible_matches
precisions = [0] * max_order
for i in range(0, max_order):
if smooth:
precisions[i] = ((matches_by_order[i] + 1.) /
(possible_matches_by_order[i] + 1.))
else:
if possible_matches_by_order[i] > 0:
precisions[i] = (float(matches_by_order[i]) /
possible_matches_by_order[i])
else:
precisions[i] = 0.0
if min(precisions) > 0:
p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
geo_mean = math.exp(p_log_sum)
else:
geo_mean = 0
ratio = float(translation_length) / reference_length
if ratio > 1.0:
bp = 1.
else:
bp = math.exp(1 - 1. / ratio)
bleu = geo_mean * bp
return bleu, precisions
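# Example: a translation identical to its single reference scores 1.0 once
# smoothing pads out the empty higher-order n-gram counts of a short
# sentence:
#
#   bleu, precisions = compute_bleu(
#       reference_corpus=[[['the', 'cat', 'sat']]],
#       translation_corpus=[['the', 'cat', 'sat']],
#       smooth=True)
#   # bleu == 1.0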