Skip to content

Commit

Permalink
Merge pull request #63 from chakki-works/enhancement/speedUp
Browse files (browse the repository at this point in the history)
Enhancement/speed up
  • Loading branch information
Hironsan authored Oct 17, 2020
2 parents 8db3fe6 + d88969e commit 29a0e1e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 38 deletions.
24 changes: 17 additions & 7 deletions seqeval/metrics/v1.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -207,7 +207,8 @@ def precision_recall_fscore_support(y_true: List[List[str]],
sample_weight: Optional[List[int]] = None,
zero_division: str = 'warn',
scheme: Optional[Type[Token]] = None,
suffix: bool = False) -> SCORES:
suffix: bool = False,
**kwargs) -> SCORES:
"""Compute precision, recall, F-measure and support for each class.
Args:
Expand Down Expand Up @@ -288,9 +289,11 @@ def precision_recall_fscore_support(y_true: List[List[str]],
modified with ``zero_division``.
"""
def extract_tp_actual_correct(y_true, y_pred, suffix, scheme):
target_names = unique_labels(y_true, y_pred, scheme, suffix)
entities_true = Entities(y_true, scheme, suffix)
entities_pred = Entities(y_pred, scheme, suffix)
# If this function is called from classification_report,
# try to reuse entities to optimize the function.
entities_true = kwargs.get('entities_true') or Entities(y_true, scheme, suffix)
entities_pred = kwargs.get('entities_pred') or Entities(y_pred, scheme, suffix)
target_names = sorted(entities_true.unique_tags | entities_pred.unique_tags)

tp_sum = np.array([], dtype=np.int32)
pred_sum = np.array([], dtype=np.int32)
Expand Down Expand Up @@ -376,7 +379,10 @@ def classification_report(y_true: List[List[str]],

if scheme is None or not issubclass(scheme, Token):
scheme = auto_detect(y_true, suffix)
target_names = unique_labels(y_true, y_pred, scheme, suffix)

entities_true = Entities(y_true, scheme, suffix)
entities_pred = Entities(y_pred, scheme, suffix)
target_names = sorted(entities_true.unique_tags | entities_pred.unique_tags)

if output_dict:
reporter = DictReporter()
Expand All @@ -393,7 +399,9 @@ def classification_report(y_true: List[List[str]],
sample_weight=sample_weight,
zero_division=zero_division,
scheme=scheme,
suffix=suffix
suffix=suffix,
entities_true=entities_true,
entities_pred=entities_pred
)
for row in zip(target_names, p, r, f1, s):
reporter.write(*row)
Expand All @@ -408,7 +416,9 @@ def classification_report(y_true: List[List[str]],
sample_weight=sample_weight,
zero_division=zero_division,
scheme=scheme,
suffix=suffix
suffix=suffix,
entities_true=entities_true,
entities_pred=entities_pred
)
reporter.write('{} avg'.format(average), avg_p, avg_r, avg_f1, support)
reporter.write_blank()
Expand Down
33 changes: 9 additions & 24 deletions seqeval/scheme.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -35,6 +35,9 @@ class Prefix(enum.Flag):
ANY = I | O | B | E | S | U | L


Prefixes = dict(Prefix.__members__)


class Tag(enum.Flag):
SAME = enum.auto()
DIFF = enum.auto()
Expand All @@ -49,25 +52,13 @@ class Token:

def __init__(self, token: str, suffix: bool = False, delimiter: str = '-'):
self.token = token
self.suffix = suffix
self.delimiter = delimiter
self.prefix = Prefixes[token[-1]] if suffix else Prefixes[token[0]]
tag = token[:-1] if suffix else token[1:]
self.tag = tag.strip(delimiter) or '_'

def __repr__(self):
return self.token

@property
def prefix(self):
"""Extracts a prefix from the token."""
prefix = self.token[-1] if self.suffix else self.token[0]
return Prefix[prefix]

@property
def tag(self):
"""Extracts a tag from the token."""
tag = self.token[:-1] if self.suffix else self.token[1:]
tag = tag.strip(self.delimiter) or '_'
return tag

def is_valid(self):
"""Check whether the prefix is allowed or not."""
if self.prefix not in self.allowed_prefix:
Expand Down Expand Up @@ -229,9 +220,9 @@ class Tokens:

def __init__(self, tokens: List[str], scheme: Type[Token],
suffix: bool = False, delimiter: str = '-', sent_id: int = None):
self.tokens = [scheme(token, suffix=suffix, delimiter=delimiter) for token in tokens]
self.scheme = scheme
self.outside_token = scheme('O', suffix=suffix, delimiter=delimiter)
self.tokens = [scheme(token, suffix=suffix, delimiter=delimiter) for token in tokens]
self.extended_tokens = self.tokens + [self.outside_token]
self.sent_id = sent_id

@property
Expand Down Expand Up @@ -276,12 +267,6 @@ def _is_end(self, i: int):
prev = self.extended_tokens[i - 1]
return token.is_end(prev)

@property
def extended_tokens(self):
# append a sentinel.
tokens = self.tokens + [self.outside_token]
return tokens


class Entities:

Expand Down Expand Up @@ -315,8 +300,8 @@ def auto_detect(sequences: List[List[str]], suffix: bool = False, delimiter: str
error_message = 'This scheme is not supported: {}'
for tokens in sequences:
for token in tokens:
token = Token(token, suffix=suffix, delimiter=delimiter)
try:
token = Token(token, suffix=suffix, delimiter=delimiter)
prefixes.add(token.prefix)
except KeyError:
raise ValueError(error_message.format(token))
Expand Down
15 changes: 8 additions & 7 deletions tests/test_scheme.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -627,22 +627,23 @@ def test_bilou_tokens_without_tag(tokens, expected):
class TestToken:

def test_raises_type_error_if_input_is_binary_string(self):
token = Token('I-組織'.encode('utf-8'))
with pytest.raises(TypeError):
tag = token.tag
with pytest.raises(KeyError):
token = Token('I-組織'.encode('utf-8'))

def test_raises_index_error_if_input_is_empty_string(self):
token = Token('')
with pytest.raises(IndexError):
prefix = token.prefix
token = Token('')

def test_representation(self):
token = Token('B-ORG')
assert 'B-ORG' == str(token)


class TestIOB2Token:

def test_invalid_prefix(self):
token = IOB2('T')
with pytest.raises(KeyError):
prefix = token.prefix
token = IOB2('T')


@pytest.mark.parametrize(
Expand Down

0 comments on commit 29a0e1e

Please sign in to comment.