Merge pull request #51 from kaushikacharya/classification_report_dict
classification_report outputs string/dict as requested in issue #41
Hironsan authored Oct 7, 2020
2 parents 6367a00 + 4216f3b commit a48a9d1
Showing 1 changed file with 51 additions and 31 deletions.
seqeval/metrics/sequence_labeling.py (82 changes: 51 additions & 31 deletions)
@@ -301,16 +301,17 @@ def performance_measure(y_true, y_pred):
     return performance_dict


-def classification_report(y_true, y_pred, digits=2, suffix=False):
+def classification_report(y_true, y_pred, digits=2, suffix=False, output_dict=False):
     """Build a text report showing the main classification metrics.

     Args:
         y_true : 2d array. Ground truth (correct) target values.
         y_pred : 2d array. Estimated targets as returned by a classifier.
         digits : int. Number of digits for formatting output floating point values.
+        output_dict : bool(default=False). If True, return output as dict else str.

     Returns:
-        report : string. Text summary of the precision, recall, F1 score for each class.
+        report : string/dict. Summary of the precision, recall, F1 score for each class.

     Examples:
         >>> from seqeval.metrics import classification_report
@@ -324,6 +325,7 @@ def classification_report(y_true, y_pred, digits=2, suffix=False):
         <BLANKLINE>
            micro avg       0.50      0.50      0.50         2
            macro avg       0.50      0.50      0.50         2
+        weighted avg       0.50      0.50      0.50         2
         <BLANKLINE>
     """
     true_entities = set(get_entities(y_true, suffix))
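
For orientation between the hunks: once this change is merged, passing the new output_dict flag should return nested dictionaries instead of the formatted string. A minimal sketch of the intended call, with made-up labels (the exact numbers depend on the data; the keys follow the assignments made later in this diff):

    from seqeval.metrics import classification_report

    # Illustrative IOB2-tagged sequences, not taken from the diff's doctest.
    y_true = [['B-PER', 'I-PER', 'O', 'B-LOC']]
    y_pred = [['B-PER', 'I-PER', 'O', 'B-LOC']]

    result = classification_report(y_true, y_pred, output_dict=True)
    # Expected shape:
    # {'LOC': {'precision': ..., 'recall': ..., 'f1-score': ..., 'support': ...},
    #  'PER': {...},
    #  'micro avg': {...}, 'macro avg': {...}, 'weighted avg': {...}}
    print(result['weighted avg']['f1-score'])
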
@@ -338,15 +340,19 @@ def classification_report(y_true, y_pred, digits=2, suffix=False):
     for e in pred_entities:
         d2[e[0]].add((e[1], e[2]))

-    last_line_heading = 'weighted avg'
-    width = max(name_width, len(last_line_heading), digits)
+    avg_types = ['micro avg', 'macro avg', 'weighted avg']

-    headers = ["precision", "recall", "f1-score", "support"]
-    head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
-    report = head_fmt.format(u'', *headers, width=width)
-    report += u'\n\n'
+    if output_dict:
+        report_dict = dict()
+    else:
+        avg_width = max([len(x) for x in avg_types])
+        width = max(name_width, avg_width, digits)
+        headers = ["precision", "recall", "f1-score", "support"]
+        head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
+        report = head_fmt.format(u'', *headers, width=width)
+        report += u'\n\n'

-    row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'
+        row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'

     ps, rs, f1s, s = [], [], [], []
     for type_name in sorted(d1.keys()):
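
As a quick aside, this is what the string branch's head_fmt and row_fmt produce; the width=12 and the 'PER' row below are assumed for illustration (12 matches the length of 'weighted avg', the widest average label):

    headers = ["precision", "recall", "f1-score", "support"]
    head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
    row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'

    # Header row, then one right-aligned metrics row per entity type.
    print(head_fmt.format(u'', *headers, width=12))
    print(row_fmt.format('PER', 1.0, 1.0, 1.0, 1, width=12, digits=2), end='')
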
@@ -360,33 +366,47 @@ def classification_report(y_true, y_pred, digits=2, suffix=False):
         r = nb_correct / nb_true if nb_true > 0 else 0
         f1 = 2 * p * r / (p + r) if p + r > 0 else 0

-        report += row_fmt.format(*[type_name, p, r, f1, nb_true], width=width, digits=digits)
+        if output_dict:
+            report_dict[type_name] = {'precision': p, 'recall': r, 'f1-score': f1, 'support': nb_true}
+        else:
+            report += row_fmt.format(*[type_name, p, r, f1, nb_true], width=width, digits=digits)

         ps.append(p)
         rs.append(r)
         f1s.append(f1)
         s.append(nb_true)

-    report += u'\n'
+    if not output_dict:
+        report += u'\n'

     # compute averages
-    report += row_fmt.format('micro avg',
-                             precision_score(y_true, y_pred, suffix=suffix),
-                             recall_score(y_true, y_pred, suffix=suffix),
-                             f1_score(y_true, y_pred, suffix=suffix),
-                             np.sum(s),
-                             width=width, digits=digits)
-    report += row_fmt.format('macro avg',
-                             np.average(ps),
-                             np.average(rs),
-                             np.average(f1s),
-                             np.sum(s),
-                             width=width, digits=digits)
-    report += row_fmt.format(last_line_heading,
-                             np.average(ps, weights=s),
-                             np.average(rs, weights=s),
-                             np.average(f1s, weights=s),
-                             np.sum(s),
-                             width=width, digits=digits)
-
-    return report
+    nb_true = np.sum(s)
+
+    for avg_type in avg_types:
+        if avg_type == 'micro avg':
+            # micro average
+            p = precision_score(y_true, y_pred, suffix=suffix)
+            r = recall_score(y_true, y_pred, suffix=suffix)
+            f1 = f1_score(y_true, y_pred, suffix=suffix)
+        elif avg_type == 'macro avg':
+            # macro average
+            p = np.average(ps)
+            r = np.average(rs)
+            f1 = np.average(f1s)
+        elif avg_type == 'weighted avg':
+            # weighted average
+            p = np.average(ps, weights=s)
+            r = np.average(rs, weights=s)
+            f1 = np.average(f1s, weights=s)
+        else:
+            assert False, "unexpected average: {}".format(avg_type)
+
+        if output_dict:
+            report_dict[avg_type] = {'precision': p, 'recall': r, 'f1-score': f1, 'support': nb_true}
+        else:
+            report += row_fmt.format(*[avg_type, p, r, f1, nb_true], width=width, digits=digits)
+
+    if output_dict:
+        return report_dict
+    else:
+        return report
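
For readers comparing the three branches above: the macro and weighted averages both come straight from numpy's np.average over the per-class lists, only the weights differ, while the micro average is computed globally by precision_score/recall_score/f1_score rather than from those lists. A small standalone illustration with invented per-class numbers:

    import numpy as np

    ps = [1.0, 0.0]   # invented per-class precisions
    s = [2, 1]        # invented per-class supports

    macro_p = np.average(ps)                # plain mean -> 0.5
    weighted_p = np.average(ps, weights=s)  # support-weighted mean -> 0.666...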
