-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
79 lines (63 loc) · 2.87 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import jsonlines
import torch
from tqdm import tqdm
from coref import CorefModel
from coref.tokenizer_customization import *
def build_doc(doc: dict, model: CorefModel) -> dict:
    """Annotate a raw input document with the fields CorefModel inference reads.

    The document is modified in place and also returned.

    Args:
        doc: a parsed input record; must contain ``"cased_words"`` (a list of
            word strings). A ``"speaker"`` entry is optional.
        model: the loaded model; its ``config.bert_model`` selects the
            tokenizer filter / override map, and ``tokenizer.tokenize`` splits
            words that have no override.

    Returns:
        The same ``doc`` object, now carrying ``word2subword``, ``subwords``,
        ``word_id``, ``head2span``, ``speaker``, ``word_clusters`` and
        ``span_clusters``.
    """
    bert_name = model.config.bert_model
    # Per-model subword filter (default: keep everything) and per-word
    # tokenization overrides (default: none).
    keep_subword = TOKENIZER_FILTERS.get(bert_name, lambda _: True)
    overrides = TOKENIZER_MAPS.get(bert_name, {})

    spans = []   # per word: half-open (start, end) indices into `pieces`
    pieces = []  # flat list of subword tokens for the whole document
    owners = []  # per subword: index of the word it came from
    for word_index, word in enumerate(doc["cased_words"]):
        if word in overrides:
            candidates = overrides[word]
        else:
            candidates = model.tokenizer.tokenize(word)
        kept = [piece for piece in candidates if keep_subword(piece)]
        # NOTE(review): if every subword of a word is filtered out, the span
        # is empty (start == end) — confirm the model tolerates that.
        start = len(pieces)
        spans.append((start, start + len(kept)))
        pieces += kept
        owners += [word_index] * len(kept)

    # Assignment order is kept stable so the written-out JSON keys keep the
    # same order.
    doc.update(word2subword=spans, subwords=pieces, word_id=owners,
               head2span=[])
    # Documents without speaker information get a placeholder per word.
    doc.setdefault("speaker", ["_"] * len(doc["cased_words"]))
    # Clusters are filled in later from the model's predictions.
    doc.update(word_clusters=[], span_clusters=[])
    return doc
def _positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers greater than 0."""
    number = int(value)  # a ValueError here is reported by argparse as invalid
    if number <= 0:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}")
    return number


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this prediction script."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument("experiment")
    argparser.add_argument("input_file")
    argparser.add_argument("output_file")
    argparser.add_argument("--config-file", default="config.toml")
    # Validated as > 0 so an explicit 0 is rejected instead of silently ignored.
    argparser.add_argument("--batch-size", type=_positive_int,
                           help="Adjust to override the config value if you're"
                                " experiencing out-of-memory issues")
    argparser.add_argument("--weights",
                           help="Path to file with weights to load."
                                " If not supplied, the latest"
                                " weights of the experiment will be loaded;"
                                " if there aren't any, an error is raised.")
    return argparser


def main(argv=None) -> None:
    """Run coreference prediction over a JSON-lines file.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]`` (``None``
            means use the real command line).

    Reads every document from ``input_file``, prepares it with ``build_doc``,
    runs the model on each one under ``torch.no_grad()``, stores the predicted
    ``span_clusters`` / ``word_clusters``, drops the helper fields added for
    inference, and writes all documents to ``output_file``.
    """
    args = _build_parser().parse_args(argv)

    model = CorefModel(args.config_file, args.experiment)

    if args.batch_size is not None:
        model.config.a_scoring_batch_size = args.batch_size

    # Optimizer/scheduler states are only needed for training.
    model.load_weights(path=args.weights, map_location="cpu",
                       ignore={"bert_optimizer", "general_optimizer",
                               "bert_scheduler", "general_scheduler"})
    # NOTE(review): presumably switches the model to evaluation mode —
    # confirm against CorefModel.
    model.training = False

    with jsonlines.open(args.input_file, mode="r") as input_data:
        docs = [build_doc(doc, model) for doc in input_data]

    with torch.no_grad():
        for doc in tqdm(docs, unit="docs"):
            result = model.run(doc)
            doc["span_clusters"] = result.span_clusters
            doc["word_clusters"] = result.word_clusters

            # Remove the inference-only fields before writing the output.
            for key in ("word2subword", "subwords", "word_id", "head2span"):
                del doc[key]

    with jsonlines.open(args.output_file, mode="w") as output_data:
        output_data.write_all(docs)


if __name__ == "__main__":
    main()