Skip to content

Commit

Permalink
Fix tests for v3.4
Browse files Browse the repository at this point in the history
  • Loading branch information
richardpaulhudson committed Sep 2, 2022
1 parent af44ca7 commit 133cd91
Show file tree
Hide file tree
Showing 12 changed files with 91 additions and 57 deletions.
12 changes: 8 additions & 4 deletions coreferee/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from os import sep
from threading import Lock
import pkg_resources
from packaging import version
import spacy
from spacy.language import Language
from spacy.tokens import Doc
Expand Down Expand Up @@ -46,13 +47,16 @@ def get_nlps(language_name: str, *, add_coreferee: bool = True) -> List[Language
config = Config().from_disk(absolute_config_filename)
nlps = []
for config_entry in config:
# At present we presume there will never be an entry in the config file that
# specifies a model name that can no longer be loaded. This seems a reasonable
# assumption, but if it no longer applies this code will need to be changed in the
# future.

nlp = spacy.load(
"_".join((language_name, config[config_entry]["model"]))
)
if version.parse(nlp.meta["version"]) < version.parse(
config[config_entry]["from_version"]
) or version.parse(nlp.meta["version"]) > version.parse(
config[config_entry]["to_version"]
):
continue
if add_coreferee:
nlp.add_pipe("coreferee")
nlps.append(nlp)
Expand Down
2 changes: 1 addition & 1 deletion sh/download_corpora.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ rm -Rf ${TEMP_DIR}

# Instructions for French
# - download DEMOCRAT corpus from https://www.ortolang.fr/market/corpora/democrat/
# - convert it to CONLL using https://github.com/Pantalaymon/neuralcoref-for-french/blob/main/conversion_conll.py
# - convert it to CONLL using https://github.com/Pantalaymon/neuralcoref-for-french/blob/main/conversion_conll.py
3 changes: 2 additions & 1 deletion tests/common/test_annotation_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def test_annotations_coordination_two_chains_long_gap(self):
)

def test_annotations_with_scoring(self):
self.compare_annotations("Richard told Peter he had finished", "[0: [0], [3]]")
self.compare_annotations("Richard told Peter he had finished", "[0: [0], [3]]",
alternative_expected_coref_chains="[0: [2], [3]]")

def test_annotations_cataphora(self):
self.compare_annotations(
Expand Down
86 changes: 50 additions & 36 deletions tests/common/test_tendencies_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ def test_get_feature_map_simple_mention(self):
],
feature_map,
)
else:
self.fail("Unsupported version.")

feature_map = self.sm_tendencies_analyzer.get_feature_map(
Mention(doc[2], False), doc
Expand Down Expand Up @@ -254,8 +252,6 @@ def test_get_feature_map_simple_token(self):
],
feature_map,
)
else:
self.fail("Unsupported version.")

feature_map = self.sm_tendencies_analyzer.get_feature_map(doc[2], doc)
self.assertEqual(len(self.sm_feature_table), len(feature_map))
Expand Down Expand Up @@ -395,8 +391,6 @@ def test_get_feature_map_conjunction(self):
],
feature_map,
)
else:
self.fail("Unsupported version")

feature_map = self.sm_tendencies_analyzer.get_feature_map(
Mention(doc[0], True), doc
Expand Down Expand Up @@ -720,8 +714,6 @@ def test_get_compatibility_map_simple(self):
Mention(doc[0], False), doc[2]
),
)
else:
self.fail("Unsupported version")

@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_get_compatibility_map_coordination(self):
Expand All @@ -742,8 +734,6 @@ def test_get_compatibility_map_coordination(self):
Mention(doc[0], True), doc[4]
),
)
else:
self.fail("Unsupported version")

@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_get_compatibility_map_different_sentences(self):
Expand All @@ -764,8 +754,6 @@ def test_get_compatibility_map_different_sentences(self):
Mention(doc[0], False), doc[3]
),
)
else:
self.fail("Unsupported version")

@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_get_compatibility_map_same_sentence_no_governance(self):
Expand All @@ -789,8 +777,6 @@ def test_get_compatibility_map_same_sentence_no_governance(self):
Mention(doc[0], False), doc[4]
),
)
else:
self.fail("Unsupported version")

@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_get_compatibility_map_same_sentence_lefthand_sibling_governance(self):
Expand All @@ -811,8 +797,6 @@ def test_get_compatibility_map_same_sentence_lefthand_sibling_governance(self):
Mention(doc[0], False), doc[4]
),
)
else:
self.fail("Unsupported version.")

@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_get_compatibility_map_same_sentence_lefthand_sibling_no_governance(self):
Expand All @@ -835,8 +819,6 @@ def test_get_compatibility_map_same_sentence_lefthand_sibling_no_governance(self
Mention(doc[1], False), doc[6]
),
)
else:
self.fail("Unsupported version.")

@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_get_cosine_similarity_lg(self):
Expand All @@ -845,12 +827,22 @@ def test_get_cosine_similarity_lg(self):
"After Richard arrived, he said he was entering the big house"
)
self.lg_rules_analyzer.initialize(doc)
self.compare_compatibility_map(
[4, 0, 0, 0.3336621, 5],
self.lg_tendencies_analyzer.get_compatibility_map(
Mention(doc[0], False), doc[4]
),
)
if self.lg_nlp.meta["version"] == "3.2.0":
self.compare_compatibility_map(
[4, 0, 0, 0.3336621, 5],
self.lg_tendencies_analyzer.get_compatibility_map(
Mention(doc[0], False), doc[4]
),
)
elif self.lg_nlp.meta["version"] == "3.4.0":
self.compare_compatibility_map(
[4, 0, 0, 0.22459798, 5],
self.lg_tendencies_analyzer.get_compatibility_map(
Mention(doc[0], False), doc[4]
),
)
else:
self.fail("Unsupported version.")

@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_get_cosine_similarity_lg_no_vector_1(self):
Expand All @@ -860,12 +852,23 @@ def test_get_cosine_similarity_lg_no_vector_1(self):
)
self.lg_rules_analyzer.initialize(doc)

self.compare_compatibility_map(
[4, 0, 0, 0.59521705, 5],
self.lg_tendencies_analyzer.get_compatibility_map(
Mention(doc[0], False), doc[4]
),
)
if self.lg_nlp.meta["version"] == "3.2.0":
self.compare_compatibility_map(
[4, 0, 0, 0.59521705, 5],
self.lg_tendencies_analyzer.get_compatibility_map(
Mention(doc[0], False), doc[4]
),
)
elif self.lg_nlp.meta["version"] == "3.4.0":
self.compare_compatibility_map(
[4, 0, 0, 0.25001550, 1],
self.lg_tendencies_analyzer.get_compatibility_map(
Mention(doc[0], False), doc[4]
),
)
else:
self.fail("Unsupported version.")


@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_get_cosine_similarity_lg_no_vector_2(self):
Expand All @@ -874,12 +877,23 @@ def test_get_cosine_similarity_lg_no_vector_2(self):
"After Richard arrived, he saifefefwefefd he was entering the big house"
)
self.lg_rules_analyzer.initialize(doc)
self.compare_compatibility_map(
[4, 0, 0, 0.59521705, 5],
self.lg_tendencies_analyzer.get_compatibility_map(
Mention(doc[0], False), doc[4]
),
)
if self.lg_nlp.meta["version"] == "3.2.0":
self.compare_compatibility_map(
[4, 0, 0, 0.59521705, 5],
self.lg_tendencies_analyzer.get_compatibility_map(
Mention(doc[0], False), doc[4]
),
)
elif self.lg_nlp.meta["version"] == "3.4.0":
self.compare_compatibility_map(
[4, 0, 0, 0.52951097, 2],
self.lg_tendencies_analyzer.get_compatibility_map(
Mention(doc[0], False), doc[4]
),
)
else:
self.fail("Unsupported version.")


@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_get_cosine_similarity_sm_root_1(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/de/test_rules_de.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from coreferee.data_model import Mention

nlps = get_nlps("de")
if len(nlps) == 0:
raise unittest.SkipTest("Model version not supported.")
train_version_mismatch = False
for nlp in nlps:
if not nlp.meta["matches_train_version"]:
Expand Down
5 changes: 3 additions & 2 deletions tests/de/test_smoke_tests_de.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
from coreferee.test_utils import get_nlps, debug_structures
from coreferee.test_utils import get_nlps

nlps = get_nlps("de")
if len(nlps) == 0:
raise unittest.SkipTest("Model version not supported.")
train_version_mismatch = False
for nlp in nlps:
if not nlp.meta["matches_train_version"]:
Expand Down Expand Up @@ -34,7 +36,6 @@ def func(nlp):
return

doc = nlp(doc_text)
debug_structures(doc)
chains_representation = str(doc._.coref_chains)
if alternative_expected_coref_chains is None:
self.assertEqual(
Expand Down
5 changes: 4 additions & 1 deletion tests/en/test_rules_en.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from coreferee.data_model import Mention

nlps = get_nlps("en")
if len(nlps) == 0:
raise unittest.SkipTest("Model version not supported.")
train_version_mismatch = False
for nlp in nlps:
if not nlp.meta["matches_train_version"]:
Expand Down Expand Up @@ -433,7 +435,8 @@ def test_potential_pair_he_she_antecedent_person_noun(self):
@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_potential_pair_he_she_antecedent_non_person_proper_noun(self):
self.compare_potential_pair(
"I worked for Skateboards plc. She was there", 4, False, 6, 1
"I worked for Skateboards plc. She was there", 4, False, 6, 1,
excluded_nlps=["core_web_sm"]
)

def test_potential_pair_it_exclusively_person_antecedent(self):
Expand Down
14 changes: 12 additions & 2 deletions tests/en/test_smoke_tests_en.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import unittest
from coreferee.test_utils import get_nlps

nlps = get_nlps("en")
if len(nlps) == 0:
raise unittest.SkipTest("Model version not supported.")
train_version_mismatch = False
for nlp in nlps:
if not nlp.meta["matches_train_version"]:
train_version_mismatch = True
train_version_mismatch_message = (
"Loaded model version does not match train model version"
)


class EnglishSmokeTest(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -31,8 +42,6 @@ def func(nlp):
expected_coref_chains, chains_representation, nlp.meta["name"]
)
else:
print(nlp.meta["name"])
print(chains_representation)
self.assertTrue(
expected_coref_chains == chains_representation
or alternative_expected_coref_chains == chains_representation
Expand Down Expand Up @@ -93,6 +102,7 @@ def test_proper_noun_coreference_multiword_only_second_repeated(self):
def test_proper_noun_coreference_multiword_only_first_repeated(self):
self.compare_annotations("I saw Peter Paul. Peter was chasing a cat.", "[]")

@unittest.skipIf(train_version_mismatch, train_version_mismatch_message)
def test_common_noun_coreference(self):
self.compare_annotations(
"I saw a big dog. The dog was chasing a cat. It was wagging its tail",
Expand Down
5 changes: 2 additions & 3 deletions tests/fr/test_rules_fr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from coreferee.test_utils import get_nlps
from coreferee.data_model import Mention

try:
nlps = get_nlps("fr")
except ModelNotSupportedError:
nlps = get_nlps("fr")
if len(nlps) == 0:
raise unittest.SkipTest("Model version not supported.")


Expand Down
5 changes: 2 additions & 3 deletions tests/fr/test_smoke_tests_fr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from coreferee.errors import ModelNotSupportedError
from coreferee.test_utils import get_nlps

try:
nlps = get_nlps("fr")
except ModelNotSupportedError:
nlps = get_nlps("fr")
if len(nlps) == 0:
raise unittest.SkipTest("Model version not supported.")

train_version_mismatch = False
Expand Down
2 changes: 2 additions & 0 deletions tests/pl/test_rules_pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from coreferee.data_model import Mention

nlps = get_nlps("pl")
if len(nlps) == 0:
raise unittest.SkipTest("Model version not supported.")
train_version_mismatch = False
for nlp in nlps:
if not nlp.meta["matches_train_version"]:
Expand Down
7 changes: 3 additions & 4 deletions tests/pl/test_smoke_tests_pl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
from coreferee.test_utils import get_nlps, debug_structures
from coreferee.test_utils import get_nlps

nlps = get_nlps("pl")
if len(nlps) == 0:
raise unittest.SkipTest("Model version not supported.")
train_version_mismatch = False
for nlp in nlps:
if not nlp.meta["matches_train_version"]:
Expand Down Expand Up @@ -30,9 +32,6 @@ def func(nlp):
return

doc = nlp(doc_text)
debug_structures(doc)
if len(doc) > 5:
print(doc[5].morph)
chains_representation = str(doc._.coref_chains)
if alternative_expected_coref_chains is None:
self.assertEqual(
Expand Down

0 comments on commit 133cd91

Please sign in to comment.