Skip to content

Commit

Permalink
fix category based ap logging (#122)
Browse files Browse the repository at this point in the history
* fix category based ap logging

* fix typo

* add class based ap html export
  • Loading branch information
fcakyon authored Jun 24, 2022
1 parent 6ccf614 commit bd8b108
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 16 deletions.
2 changes: 1 addition & 1 deletion yolov5/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from yolov5.helpers import YOLOv5
from yolov5.helpers import load_model as load

__version__ = "6.1.3"
__version__ = "6.1.4"
24 changes: 12 additions & 12 deletions yolov5/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,22 +452,22 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
if not noval or final_epoch: # Calculate mAP
results, maps, _ = val.run(data_dict,
batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz,
model=ema.ema,
single_cls=single_cls,
dataloader=val_loader,
save_dir=save_dir,
plots=False,
callbacks=callbacks,
compute_loss=compute_loss)
results, maps, map50s, _ = val.run(data_dict,
batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz,
model=ema.ema,
single_cls=single_cls,
dataloader=val_loader,
save_dir=save_dir,
plots=False,
callbacks=callbacks,
compute_loss=compute_loss)

# Update best mAP
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, [email protected], [email protected]]
if fi > best_fitness:
best_fitness = fi
log_vals = list(mloss) + list(results) + lr + list(maps)
log_vals = list(mloss) + list(results) + lr + list(maps) + list(map50s)
callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)

# Save model
Expand Down Expand Up @@ -534,7 +534,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
strip_optimizer(f) # strip optimizers
if f is best:
LOGGER.info(f'\nValidating {f}...')
results, _, _ = val.run(
results, _, _, _ = val.run(
data_dict,
batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz,
Expand Down
4 changes: 2 additions & 2 deletions yolov5/utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None,
self.csv = True # always log to csv
self.class_names = class_names
if not mmdet_keys:
self.class_name_keys = ['metrics/' + name + '_mAP_50' for name in class_names]
self.class_name_keys = ['metrics/' + name + '_mAP' for name in class_names] + ['metrics/' + name + '_mAP_50' for name in class_names]
else:
self.class_name_keys = ['val/' + name + '_mAP_50' for name in class_names]
self.class_name_keys = ['val/' + name + '_mAP' for name in class_names] + ['val/' + name + '_mAP_50' for name in class_names]
self.s3_weight_folder = None if not opt.s3_upload_dir else "s3://" + str(Path(opt.s3_upload_dir.replace("s3://","")) / save_dir.name / "weights").replace(os.sep, '/')

# Message
Expand Down
19 changes: 18 additions & 1 deletion yolov5/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from threading import Thread

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

Expand Down Expand Up @@ -294,6 +295,19 @@ def run(
for i, c in enumerate(ap_class):
LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

# Export results as html
header = "Class Images Labels P R [email protected] [email protected]:.95"
headers = header.split()
data = []
data.append(['all', seen, nt.sum(), f"{float(mp):0.3f}", f"{float(mr):0.3f}", f"{float(map50):0.3f}", f"{float(map):0.3f}"])
for i, c in enumerate(ap_class):
data.append([names[c], seen, nt[c], f"{float(p[i]):0.3f}", f"{float(r[i]):0.3f}", f"{float(ap50[i]):0.3f}", f"{float(ap[i]):0.3f}"])
results_df = pd.DataFrame(data,columns=headers)
results_html = results_df.to_html()
text_file = open(save_dir / "results.html", "w")
text_file.write(results_html)
text_file.close()

# Print speeds
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
if not training:
Expand Down Expand Up @@ -339,7 +353,10 @@ def run(
maps = np.zeros(nc) + map
for i, c in enumerate(ap_class):
maps[c] = ap[i]
return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
map50s = np.zeros(nc) + map50
for i, c in enumerate(ap_class):
map50s[c] = ap50[i]
return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, map50s, t


def parse_opt():
Expand Down

0 comments on commit bd8b108

Please sign in to comment.