-
Notifications
You must be signed in to change notification settings - Fork 1
/
rank.py
65 lines (45 loc) · 1.61 KB
/
rank.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import numpy
from sentence_retriever import getSentences
from grammar import correct_grammar
from loadconfig import loadConfig
import random
import os
# Restrict this process to GPU 0. The variable only takes effect if CUDA has
# not been initialised yet in this process; torch initialises CUDA lazily, so
# setting it after `import torch` works as long as nothing above touched CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
# Phrase list from the 'Rank' config section (via the project's loadConfig).
# rankContext only uses it for membership tests: a `commonsense` phrase found
# here gets `extra` appended AFTER it instead of prepended before it.
# NOTE(review): presumably a list/set of strings -- confirm against loadConfig.
correct_phrase = loadConfig('Rank')
def getRoberta():
    """Fetch RoBERTa-large fine-tuned on MNLI from torch.hub, ready for inference.

    The model is moved to the GPU and put into evaluation mode (dropout
    disabled) before it is handed back to the caller.
    """
    hub_repo, hub_model = 'pytorch/fairseq', 'roberta.large.mnli'
    model = torch.hub.load(hub_repo, hub_model)
    model.cuda()  # called for its side effect; return value intentionally unused
    model.eval()
    return model
def getContradictionScores(roberta,sentences,rov):
    """Score every candidate sentence by how strongly it contradicts ``rov``.

    Each sentence is first passed through ``correct_grammar(rov, sent, gender)``.
    The ``gender`` value returned for one sentence is fed into the call for the
    next one, so the order of ``sentences`` matters. The corrected sentence is
    encoded as the pair ``(rov, sentence)`` and run through the model's
    ``'mnli'`` classification head.

    Args:
        roberta: fairseq RoBERTa hub model with an ``'mnli'`` head
            (e.g. the object returned by ``getRoberta``).
        sentences: iterable of candidate sentences (str).
        rov: reference sentence each candidate is paired with.

    Returns:
        List of ``(score, sentence)`` tuples in input order. Each part is:

        * ``score``: column 0 of the ``'mnli'`` head output for the pair,
          rounded to 3 decimals. For ``roberta.large.mnli`` index 0 is the
          contradiction class, and fairseq's ``predict`` returns
          log-probabilities by default.
        * ``sentence``: the grammar-corrected sentence after
          ``str.capitalize``, which also lower-cases every other character.
    """
    scores = []
    gender = ''
    # Pure inference: without no_grad every predict() call would build and
    # hold an autograd graph that is thrown away immediately afterwards.
    with torch.no_grad():
        for sent in sentences:
            sent, gender = correct_grammar(rov, sent, gender)
            tokens = roberta.encode(rov, sent)
            log_probs = roberta.predict('mnli', tokens)
            # Row 0 = the single encoded pair, column 0 = contradiction class.
            contradiction = round(log_probs[0, 0].item(), 3)
            scores.append((contradiction, sent.capitalize()))
    return scores
def rank_sentences_based_on_contradiction(roberta,sentences,rov):
    """Return the single most contradictory candidate for ``rov``.

    Candidates are scored with ``getContradictionScores``.

    Args:
        roberta: fairseq RoBERTa MNLI hub model.
        sentences: candidate sentences to score.
        rov: reference sentence.

    Returns:
        The ``(score, sentence)`` tuple with the highest score. Ties are
        resolved in this order:

        1. the sentence with fewer whitespace-separated words wins;
        2. if still tied, the earliest candidate wins.

    Raises:
        IndexError: if there are no candidates to rank.
    """
    scores = getContradictionScores(roberta, sentences, rov)
    if not scores:
        raise IndexError('no candidate sentences to rank')
    # One O(n) pass instead of a full sort. max() keeps the first of equal
    # keys, which is exactly what a stable reverse sort followed by [0] picks.
    return max(scores, key=lambda pair: (pair[0], -len(pair[1].split())))
def getSentenceforNSI(sentences,rov,commonsense,extra=''):
    """Return one randomly chosen sentence retrieved for ``commonsense``.

    NOTE(review): ``sentences``, ``rov`` and ``extra`` are accepted but ignored.
    The candidate pool is always re-fetched with
    ``getSentences(commonsense, '', False)``. ``random.choice`` raises
    IndexError if that pool is empty.
    """
    return random.choice(getSentences(commonsense, '', False))
def _insert_extra(sentence, commonsense, replacement):
    """Lower-case ``sentence`` and splice ``replacement`` in for ``commonsense``.

    The sentence is returned untouched if ``replacement`` already occurs in it,
    compared case-insensitively.
    """
    lowered = sentence.lower()
    # Case-insensitive check: a sentence that *starts* with the phrase
    # ("Iced coffee ...") must not be rewritten to "Iced iced coffee ...".
    if replacement.lower() in lowered:
        return sentence
    return lowered.replace(commonsense, replacement).capitalize()


def rankContext(roberta,rov,commonsense,extra=''):
    """Pick the retrieved context sentence that most contradicts ``rov``.

    Steps:

    1. Fetch candidates with ``getSentences(commonsense, rov)``.
    2. If ``extra`` is non-empty, place it next to ``commonsense`` in every
       candidate that does not already contain the combined phrase. It goes
       after ``commonsense`` when ``commonsense`` is listed in the module-level
       ``correct_phrase``, otherwise before it.
    3. Rank the candidates with ``rank_sentences_based_on_contradiction``.
    4. Return the winner with its first letter capitalised and a standalone
       " i " restored to " I ".

    Args:
        roberta: fairseq RoBERTa MNLI hub model (see ``getRoberta``).
        rov: reference sentence the candidates are compared against.
        commonsense: phrase used to retrieve candidates and as the anchor
            for inserting ``extra``.
        extra: optional modifier to insert next to ``commonsense``. The
            default ``''`` (or any falsy value such as ``None``) disables
            insertion.

    Returns:
        The selected sentence (str).
    """
    sentences = getSentences(commonsense, rov)
    if extra:
        # Word order depends only on ``commonsense``, so decide it once.
        if commonsense in correct_phrase:
            replacement = commonsense + ' ' + extra
        else:
            replacement = extra + ' ' + commonsense
        # Build a new list rather than mutating the one getSentences returned.
        sentences = [_insert_extra(s, commonsense, replacement) for s in sentences]
    mostcontradictory = rank_sentences_based_on_contradiction(roberta, sentences, rov)
    best = mostcontradictory[1].capitalize()
    # capitalize() lower-cases everything after the first character; restore
    # the pronoun "I" where it stands alone mid-sentence.
    return best.replace(' i ', ' I ')