From 41603da1624a8590889349e75be02001c6c31899 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 25 Feb 2024 14:04:01 +0100 Subject: [PATCH] Add docstrings to YOLOv5 functions (#12760) * Add docstrings to top level files * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Add docstrings * Auto-format by https://ultralytics.com/actions * Update activations.py Signed-off-by: Glenn Jocher * Auto-format by https://ultralytics.com/actions --------- Signed-off-by: Glenn Jocher Co-authored-by: UltralyticsAssistant --- benchmarks.py | 2 + classify/predict.py | 2 + classify/train.py | 12 +- classify/val.py | 2 + detect.py | 3 + export.py | 46 ++++++-- hubconf.py | 42 +++++-- models/common.py | 133 +++++++++++++++++++-- models/experimental.py | 15 ++- models/tf.py | 114 ++++++++++++++++-- models/yolo.py | 33 +++++- segment/predict.py | 4 + segment/train.py | 13 ++- segment/val.py | 14 ++- train.py | 10 +- utils/__init__.py | 180 +++++++++++++++-------------- utils/activations.py | 43 +++++-- utils/augmentations.py | 47 ++++++-- utils/autoanchor.py | 4 +- utils/autobatch.py | 4 +- utils/callbacks.py | 2 +- utils/dataloaders.py | 109 +++++++++++++---- utils/downloads.py | 20 +++- utils/flask_rest_api/restapi.py | 3 + utils/general.py | 161 +++++++++++++++++--------- utils/loggers/__init__.py | 39 ++++--- utils/loggers/comet/__init__.py | 29 +++++ utils/loggers/comet/comet_utils.py | 1 + utils/loggers/comet/hpo.py | 4 + utils/loggers/wandb/wandb_utils.py | 1 + utils/loss.py | 19 ++- utils/metrics.py | 26 ++++- utils/plots.py | 42 +++++-- utils/segment/augmentations.py | 6 +- utils/segment/dataloaders.py | 4 +- utils/segment/general.py | 4 +- utils/segment/loss.py | 9 +- utils/segment/metrics.py | 10 +- utils/segment/plots.py | 8 +- utils/torch_utils.py | 72 ++++++++---- utils/triton.py | 1 + val.py | 12 +- 42 files changed, 983 insertions(+), 322 deletions(-) diff --git a/benchmarks.py b/benchmarks.py index 09e82e588a2a..100cabacdc97 100644 --- a/benchmarks.py +++ b/benchmarks.py @@ -149,6 +149,7 @@ def test( def parse_opt(): + """Parses command-line arguments for YOLOv5 model inference configuration.""" parser = argparse.ArgumentParser() parser.add_argument("--weights", type=str, default=ROOT / "yolov5s.pt", help="weights path") parser.add_argument("--imgsz", "--img", "--img-size", type=int, default=640, help="inference size (pixels)") @@ -166,6 +167,7 @@ def parse_opt(): def main(opt): + """Executes a test run if `opt.test` is True, otherwise starts training or inference with 
provided options.""" test(**vars(opt)) if opt.test else run(**vars(opt)) diff --git a/classify/predict.py b/classify/predict.py index b7d2f05d7bce..3139d82e7b7d 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -204,6 +204,7 @@ def run( def parse_opt(): + """Parses command line arguments for YOLOv5 inference settings including model, source, device, and image size.""" parser = argparse.ArgumentParser() parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s-cls.pt", help="model path(s)") parser.add_argument("--source", type=str, default=ROOT / "data/images", help="file/dir/URL/glob/screen/0(webcam)") @@ -229,6 +230,7 @@ def parse_opt(): def main(opt): + """Executes YOLOv5 model inference with options for ONNX DNN and video frame-rate stride adjustments.""" check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop")) run(**vars(opt)) diff --git a/classify/train.py b/classify/train.py index 63befed0f780..5556e03edff5 100644 --- a/classify/train.py +++ b/classify/train.py @@ -76,6 +76,7 @@ def train(opt, device): + """Trains a YOLOv5 model, managing datasets, model optimization, logging, and saving checkpoints.""" init_seeds(opt.seed + 1 + RANK, deterministic=True) save_dir, data, bs, epochs, nw, imgsz, pretrained = ( opt.save_dir, @@ -306,6 +307,9 @@ def train(opt, device): def parse_opt(known=False): + """Parses command line arguments for YOLOv5 training including model path, dataset, epochs, and more, returning + parsed arguments. + """ parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="yolov5s-cls.pt", help="initial weights path") parser.add_argument("--data", type=str, default="imagenette160", help="cifar10, cifar100, mnist, imagenet, ...") @@ -333,7 +337,7 @@ def parse_opt(known=False): def main(opt): - # Checks + """Executes YOLOv5 training with given options, handling device setup and DDP mode; includes pre-training checks.""" if RANK in {-1, 0}: print_args(vars(opt)) check_git_status() @@ -357,7 +361,11 @@ def main(opt): def run(**kwargs): - # Usage: from yolov5 import classify; classify.train.run(data=mnist, imgsz=320, model='yolov5m') + """ + Executes YOLOv5 model training or inference with specified parameters, returning updated options. 
+ + Example: from yolov5 import classify; classify.train.run(data=mnist, imgsz=320, model='yolov5m') + """ opt = parse_opt(True) for k, v in kwargs.items(): setattr(opt, k, v) diff --git a/classify/val.py b/classify/val.py index b170253d6e0c..427618791d65 100644 --- a/classify/val.py +++ b/classify/val.py @@ -147,6 +147,7 @@ def run( def parse_opt(): + """Parses and returns command line arguments for YOLOv5 model evaluation and inference settings.""" parser = argparse.ArgumentParser() parser.add_argument("--data", type=str, default=ROOT / "../datasets/mnist", help="dataset path") parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s-cls.pt", help="model.pt path(s)") @@ -166,6 +167,7 @@ def parse_opt(): def main(opt): + """Executes the YOLOv5 model prediction workflow, handling argument parsing and requirement checks.""" check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop")) run(**vars(opt)) diff --git a/detect.py b/detect.py index b7d77ef431d4..c58aa80a68fc 100644 --- a/detect.py +++ b/detect.py @@ -166,6 +166,7 @@ def run( # Create or append to the CSV file def write_to_csv(image_name, prediction, confidence): + """Writes prediction data for an image to a CSV file, appending if the file exists.""" data = {"Image Name": image_name, "Prediction": prediction, "Confidence": confidence} with open(csv_path, mode="a", newline="") as f: writer = csv.DictWriter(f, fieldnames=data.keys()) @@ -264,6 +265,7 @@ def write_to_csv(image_name, prediction, confidence): def parse_opt(): + """Parses command-line arguments for YOLOv5 detection, setting inference options and model configurations.""" parser = argparse.ArgumentParser() parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="model path or triton URL") parser.add_argument("--source", type=str, default=ROOT / "data/images", help="file/dir/URL/glob/screen/0(webcam)") @@ -300,6 +302,7 @@ def parse_opt(): def main(opt): + """Executes YOLOv5 model inference with given options, checking requirements before running the model.""" check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop")) run(**vars(opt)) diff --git a/export.py b/export.py index 99ea15e7030c..9ea2b936d740 100644 --- a/export.py +++ b/export.py @@ -92,6 +92,7 @@ class iOSModel(torch.nn.Module): def __init__(self, model, im): + """Initializes an iOS compatible model with normalization based on image dimensions.""" super().__init__() b, c, h, w = im.shape # batch, channel, height, width self.model = model @@ -104,12 +105,13 @@ def __init__(self, model, im): # self.normalize = torch.tensor([1. / w, 1. / h, 1. / w, 1. 
/ h]).expand(np, 4) # explicit (faster, larger) def forward(self, x): + """Runs forward pass on the input tensor, returning class confidences and normalized coordinates.""" xywh, conf, cls = self.model(x)[0].squeeze().split((4, 1, self.nc), 1) return cls * conf, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4) def export_formats(): - # YOLOv5 export formats + """Returns a DataFrame of supported YOLOv5 model export formats and their properties.""" x = [ ["PyTorch", "-", ".pt", True, True], ["TorchScript", "torchscript", ".torchscript", True, True], @@ -128,7 +130,7 @@ def export_formats(): def try_export(inner_func): - # YOLOv5 export decorator, i..e @try_export + """Decorator @try_export for YOLOv5 model export functions that logs success/failure, time taken, and file size.""" inner_args = get_default_args(inner_func) def outer_func(*args, **kwargs): @@ -147,7 +149,9 @@ def outer_func(*args, **kwargs): @try_export def export_torchscript(model, im, file, optimize, prefix=colorstr("TorchScript:")): - # YOLOv5 TorchScript model export + """Exports YOLOv5 model to TorchScript format, optionally optimized for mobile, with image shape and stride + metadata. + """ LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...") f = file.with_suffix(".torchscript") @@ -163,7 +167,7 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr("TorchScript:" @try_export def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr("ONNX:")): - # YOLOv5 ONNX export + """Exports a YOLOv5 model to ONNX format with dynamic axes and optional simplification.""" check_requirements("onnx>=1.12.0") import onnx @@ -276,7 +280,9 @@ def transform_fn(data_item): @try_export def export_paddle(model, im, file, metadata, prefix=colorstr("PaddlePaddle:")): - # YOLOv5 Paddle export + """Exports a YOLOv5 model to PaddlePaddle format using X2Paddle, saving to `save_dir` and adding a metadata.yaml + file. + """ check_requirements(("paddlepaddle", "x2paddle")) import x2paddle from x2paddle.convert import pytorch2paddle @@ -291,7 +297,7 @@ def export_paddle(model, im, file, metadata, prefix=colorstr("PaddlePaddle:")): @try_export def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")): - # YOLOv5 CoreML export + """Exports YOLOv5 model to CoreML format with optional NMS, INT8, and FP16 support; requires coremltools.""" check_requirements("coremltools") import coremltools as ct @@ -316,7 +322,11 @@ def export_coreml(model, im, file, int8, half, nms, prefix=colorstr("CoreML:")): @try_export def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr("TensorRT:")): - # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt + """ + Exports a YOLOv5 model to TensorRT engine format, requiring GPU and TensorRT>=7.0.0. + + https://developer.nvidia.com/tensorrt + """ assert im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. 
`python export.py --device 0`" try: import tensorrt as trt @@ -440,7 +450,7 @@ def export_saved_model( @try_export def export_pb(keras_model, file, prefix=colorstr("TensorFlow GraphDef:")): - # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow + """Exports YOLOv5 model to TensorFlow GraphDef *.pb format; see https://github.com/leimao/Frozen_Graph_TensorFlow for details.""" import tensorflow as tf from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 @@ -493,7 +503,11 @@ def export_tflite( @try_export def export_edgetpu(file, prefix=colorstr("Edge TPU:")): - # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/ + """ + Exports a YOLOv5 model to Edge TPU compatible TFLite format; requires Linux and Edge TPU compiler. + + https://coral.ai/docs/edgetpu/models-intro/ + """ cmd = "edgetpu_compiler --version" help_url = "https://coral.ai/docs/edgetpu/compiler/" assert platform.system() == "Linux", f"export only supported on Linux. See {help_url}" @@ -531,7 +545,7 @@ def export_edgetpu(file, prefix=colorstr("Edge TPU:")): @try_export def export_tfjs(file, int8, prefix=colorstr("TensorFlow.js:")): - # YOLOv5 TensorFlow.js export + """Exports a YOLOv5 model to TensorFlow.js format, optionally with uint8 quantization.""" check_requirements("tensorflowjs") import tensorflowjs as tfjs @@ -568,7 +582,11 @@ def export_tfjs(file, int8, prefix=colorstr("TensorFlow.js:")): def add_tflite_metadata(file, metadata, num_outputs): - # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata + """ + Adds TFLite metadata to a model file, supporting multiple outputs, as specified by TensorFlow guidelines. + + https://www.tensorflow.org/lite/models/convert/metadata + """ with contextlib.suppress(ImportError): # check_requirements('tflite_support') from tflite_support import flatbuffers @@ -601,7 +619,9 @@ def add_tflite_metadata(file, metadata, num_outputs): def pipeline_coreml(model, im, file, names, y, prefix=colorstr("CoreML Pipeline:")): - # YOLOv5 CoreML pipeline + """Converts a PyTorch YOLOv5 model to CoreML format with NMS, handling different input/output shapes and saving the + model. 
+ """ import coremltools as ct from PIL import Image @@ -869,6 +889,7 @@ def run( def parse_opt(known=False): + """Parses command-line arguments for YOLOv5 model export configurations, returning the parsed options.""" parser = argparse.ArgumentParser() parser.add_argument("--data", type=str, default=ROOT / "data/coco128.yaml", help="dataset.yaml path") parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="model.pt path(s)") @@ -904,6 +925,7 @@ def parse_opt(known=False): def main(opt): + """Executes the YOLOv5 model inference or export with specified weights and options.""" for opt.weights in opt.weights if isinstance(opt.weights, list) else [opt.weights]: run(**vars(opt)) diff --git a/hubconf.py b/hubconf.py index 691d8eb64749..53afdff62aea 100644 --- a/hubconf.py +++ b/hubconf.py @@ -84,57 +84,77 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo def custom(path="path/to/model.pt", autoshape=True, _verbose=True, device=None): - # YOLOv5 custom or local model + """Loads a custom or local YOLOv5 model from a given path with optional autoshaping and device specification.""" return _create(path, autoshape=autoshape, verbose=_verbose, device=device) def yolov5n(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None): - # YOLOv5-nano model https://github.com/ultralytics/yolov5 + """Instantiates the YOLOv5-nano model with options for pretraining, input channels, class count, autoshaping, + verbosity, and device. + """ return _create("yolov5n", pretrained, channels, classes, autoshape, _verbose, device) def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None): - # YOLOv5-small model https://github.com/ultralytics/yolov5 + """Creates YOLOv5-small model with options for pretraining, input channels, class count, autoshaping, verbosity, and + device. + """ return _create("yolov5s", pretrained, channels, classes, autoshape, _verbose, device) def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None): - # YOLOv5-medium model https://github.com/ultralytics/yolov5 + """Instantiates the YOLOv5-medium model with customizable pretraining, channel count, class count, autoshaping, + verbosity, and device. + """ return _create("yolov5m", pretrained, channels, classes, autoshape, _verbose, device) def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None): - # YOLOv5-large model https://github.com/ultralytics/yolov5 + """Creates YOLOv5-large model with options for pretraining, channels, classes, autoshaping, verbosity, and device + selection. + """ return _create("yolov5l", pretrained, channels, classes, autoshape, _verbose, device) def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None): - # YOLOv5-xlarge model https://github.com/ultralytics/yolov5 + """Instantiates the YOLOv5-xlarge model with customizable pretraining, channel count, class count, autoshaping, + verbosity, and device. + """ return _create("yolov5x", pretrained, channels, classes, autoshape, _verbose, device) def yolov5n6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None): - # YOLOv5-nano-P6 model https://github.com/ultralytics/yolov5 + """Creates YOLOv5-nano-P6 model with options for pretraining, channels, classes, autoshaping, verbosity, and + device. 
+ """ return _create("yolov5n6", pretrained, channels, classes, autoshape, _verbose, device) def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None): - # YOLOv5-small-P6 model https://github.com/ultralytics/yolov5 + """Instantiate YOLOv5-small-P6 model with options for pretraining, input channels, number of classes, autoshaping, + verbosity, and device selection. + """ return _create("yolov5s6", pretrained, channels, classes, autoshape, _verbose, device) def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None): - # YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5 + """Creates YOLOv5-medium-P6 model with options for pretraining, channel count, class count, autoshaping, verbosity, + and device. + """ return _create("yolov5m6", pretrained, channels, classes, autoshape, _verbose, device) def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None): - # YOLOv5-large-P6 model https://github.com/ultralytics/yolov5 + """Instantiates the YOLOv5-large-P6 model with customizable pretraining, channel and class counts, autoshaping, + verbosity, and device selection. + """ return _create("yolov5l6", pretrained, channels, classes, autoshape, _verbose, device) def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None): - # YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5 + """Creates YOLOv5-xlarge-P6 model with options for pretraining, channels, classes, autoshaping, verbosity, and + device. + """ return _create("yolov5x6", pretrained, channels, classes, autoshape, _verbose, device) diff --git a/models/common.py b/models/common.py index b21b42b00d0d..fd8c998149f5 100644 --- a/models/common.py +++ b/models/common.py @@ -22,7 +22,7 @@ from PIL import Image from torch.cuda import amp -# Import 'ultralytics' package or install if if missing +# Import 'ultralytics' package or install if missing try: import ultralytics @@ -71,15 +71,18 @@ class Conv(nn.Module): default_act = nn.SiLU() # default activation def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): + """Initializes a standard convolution layer with optional batch normalization and activation.""" super().__init__() self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) self.bn = nn.BatchNorm2d(c2) self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() def forward(self, x): + """Applies a convolution followed by batch normalization and an activation function to the input tensor `x`.""" return self.act(self.bn(self.conv(x))) def forward_fuse(self, x): + """Applies a fused convolution and activation function to the input tensor `x`.""" return self.act(self.conv(x)) @@ -98,6 +101,11 @@ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stri class TransformerLayer(nn.Module): # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance) def __init__(self, c, num_heads): + """ + Initializes a transformer layer, sans LayerNorm for performance, with multihead attention and linear layers. + + See as described in https://arxiv.org/abs/2010.11929. 
+ """ super().__init__() self.q = nn.Linear(c, c, bias=False) self.k = nn.Linear(c, c, bias=False) @@ -107,6 +115,7 @@ def __init__(self, c, num_heads): self.fc2 = nn.Linear(c, c, bias=False) def forward(self, x): + """Performs forward pass using MultiheadAttention and two linear transformations with residual connections.""" x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x x = self.fc2(self.fc1(x)) + x return x @@ -115,6 +124,9 @@ def forward(self, x): class TransformerBlock(nn.Module): # Vision Transformer https://arxiv.org/abs/2010.11929 def __init__(self, c1, c2, num_heads, num_layers): + """Initializes a Transformer block for vision tasks, adapting dimensions if necessary and stacking specified + layers. + """ super().__init__() self.conv = None if c1 != c2: @@ -124,6 +136,9 @@ def __init__(self, c1, c2, num_heads, num_layers): self.c2 = c2 def forward(self, x): + """Processes input through an optional convolution, followed by Transformer layers and position embeddings for + object detection. + """ if self.conv is not None: x = self.conv(x) b, _, w, h = x.shape @@ -141,6 +156,9 @@ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcu self.add = shortcut and c1 == c2 def forward(self, x): + """Processes input through two convolutions, optionally adds shortcut if channel dimensions match; input is a + tensor. + """ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) @@ -158,6 +176,9 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) def forward(self, x): + """Performs forward pass by applying layers, activation, and concatenation on input x, returning feature- + enhanced output. + """ y1 = self.cv3(self.m(self.cv1(x))) y2 = self.cv2(x) return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1)))) @@ -166,7 +187,12 @@ def forward(self, x): class CrossConv(nn.Module): # Cross Convolution Downsample def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): - # ch_in, ch_out, kernel, stride, groups, expansion, shortcut + """ + Initializes CrossConv with downsampling, expanding, and optionally shortcutting; `c1` input, `c2` output + channels. + + Inputs are ch_in, ch_out, kernel, stride, groups, expansion, shortcut. + """ super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, (1, k), (1, s)) @@ -174,6 +200,7 @@ def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): self.add = shortcut and c1 == c2 def forward(self, x): + """Performs feature sampling, expanding, and applies shortcut if channels match; expects `x` input tensor.""" return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) @@ -188,12 +215,16 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) def forward(self, x): + """Performs forward propagation using concatenated outputs from two convolutions and a Bottleneck sequence.""" return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) class C3x(C3): # C3 module with cross-convolutions def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initializes C3x module with cross-convolutions, extending C3 with customizable channel dimensions, groups, + and expansion. 
+ """ super().__init__(c1, c2, n, shortcut, g, e) c_ = int(c2 * e) self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n))) @@ -202,6 +233,9 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): class C3TR(C3): # C3 module with TransformerBlock() def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initializes C3 module with TransformerBlock for enhanced feature extraction, accepts channel sizes, shortcut + config, group, and expansion. + """ super().__init__(c1, c2, n, shortcut, g, e) c_ = int(c2 * e) self.m = TransformerBlock(c_, c_, 4, n) @@ -210,6 +244,9 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): class C3SPP(C3): # C3 module with SPP() def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5): + """Initializes a C3 module with SPP layer for advanced spatial feature extraction, given channel sizes, kernel + sizes, shortcut, group, and expansion ratio. + """ super().__init__(c1, c2, n, shortcut, g, e) c_ = int(c2 * e) self.m = SPP(c_, c_, k) @@ -218,6 +255,7 @@ def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5): class C3Ghost(C3): # C3 module with GhostBottleneck() def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initializes YOLOv5's C3 module with Ghost Bottlenecks for efficient feature extraction.""" super().__init__(c1, c2, n, shortcut, g, e) c_ = int(c2 * e) # hidden channels self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n))) @@ -226,6 +264,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): class SPP(nn.Module): # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729 def __init__(self, c1, c2, k=(5, 9, 13)): + """Initializes SPP layer with Spatial Pyramid Pooling, ref: https://arxiv.org/abs/1406.4729, args: c1 (input channels), c2 (output channels), k (kernel sizes).""" super().__init__() c_ = c1 // 2 # hidden channels self.cv1 = Conv(c1, c_, 1, 1) @@ -233,6 +272,9 @@ def __init__(self, c1, c2, k=(5, 9, 13)): self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) def forward(self, x): + """Applies convolution and max pooling layers to the input tensor `x`, concatenates results, and returns output + tensor. + """ x = self.cv1(x) with warnings.catch_warnings(): warnings.simplefilter("ignore") # suppress torch 1.9.0 max_pool2d() warning @@ -249,6 +291,7 @@ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13)) self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) def forward(self, x): + """Processes input through a series of convolutions and max pooling operations for feature extraction.""" x = self.cv1(x) with warnings.catch_warnings(): warnings.simplefilter("ignore") # suppress torch 1.9.0 max_pool2d() warning @@ -278,6 +321,7 @@ def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, s self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act) def forward(self, x): + """Performs forward pass, concatenating outputs of two convolutions on input `x`: shape (B,C,H,W).""" y = self.cv1(x) return torch.cat((y, self.cv2(y)), 1) @@ -297,16 +341,23 @@ def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride ) def forward(self, x): + """Processes input through conv and shortcut layers, returning their summed output.""" return self.conv(x) + self.shortcut(x) class Contract(nn.Module): # Contract width-height into channels, i.e. 
x(1,64,80,80) to x(1,256,40,40) def __init__(self, gain=2): + """Initializes a layer to contract spatial dimensions (width-height) into channels, e.g., input shape + (1,64,80,80) to (1,256,40,40). + """ super().__init__() self.gain = gain def forward(self, x): + """Processes input tensor to expand channel dimensions by contracting spatial dimensions, yielding output shape + `(b, c*s*s, h//s, w//s)`. + """ b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain' s = self.gain x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2) @@ -317,10 +368,19 @@ def forward(self, x): class Expand(nn.Module): # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160) def __init__(self, gain=2): + """ + Initializes the Expand module to increase spatial dimensions by redistributing channels, with an optional gain + factor. + + Example: x(1,64,80,80) to x(1,16,160,160). + """ super().__init__() self.gain = gain def forward(self, x): + """Processes input tensor x to expand spatial dimensions by redistributing channels, requiring C / gain^2 == + 0. + """ b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain' s = self.gain x = x.view(b, s, s, c // s**2, h, w) # x(1,2,2,16,80,80) @@ -331,17 +391,21 @@ def forward(self, x): class Concat(nn.Module): # Concatenate a list of tensors along dimension def __init__(self, dimension=1): + """Initializes a Concat module to concatenate tensors along a specified dimension.""" super().__init__() self.d = dimension def forward(self, x): + """Concatenates a list of tensors along a specified dimension; `x` is a list of tensors, `dimension` is an + int. + """ return torch.cat(x, self.d) class DetectMultiBackend(nn.Module): # YOLOv5 MultiBackend class for python inference on various backends def __init__(self, weights="yolov5s.pt", device=torch.device("cpu"), dnn=False, data=None, fp16=False, fuse=True): - # Usage: + """Initializes DetectMultiBackend with support for various inference backends, including PyTorch and ONNX.""" # PyTorch: weights = *.pt # TorchScript: *.torchscript # ONNX Runtime: *.onnx @@ -462,11 +526,13 @@ def __init__(self, weights="yolov5s.pt", device=torch.device("cpu"), dnn=False, import tensorflow as tf def wrap_frozen_graph(gd, inputs, outputs): + """Wraps a TensorFlow GraphDef for inference, returning a pruned function.""" x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped ge = x.graph.as_graph_element return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs)) def gd_outputs(gd): + """Generates a sorted list of graph outputs excluding NoOp nodes and inputs, formatted as ':0'.""" name_list, input_list = [], [] for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef name_list.append(node.name) @@ -540,7 +606,7 @@ def gd_outputs(gd): self.__dict__.update(locals()) # assign all variables to self def forward(self, im, augment=False, visualize=False): - # YOLOv5 MultiBackend inference + """Performs YOLOv5 inference on input images with options for augmentation and visualization.""" b, ch, h, w = im.shape # batch, channel, height, width if self.fp16 and im.dtype != torch.float16: im = im.half() # to FP16 @@ -622,10 +688,11 @@ def forward(self, im, augment=False, visualize=False): return self.from_numpy(y) def from_numpy(self, x): + """Converts a NumPy array to a torch tensor, maintaining device compatibility.""" return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x def warmup(self, imgsz=(1, 3, 
640, 640)): - # Warmup model by running inference once + """Performs a single inference warmup to initialize model weights, accepting an `imgsz` tuple for image size.""" warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton if any(warmup_types) and (self.device.type != "cpu" or self.triton): im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input @@ -634,7 +701,11 @@ def warmup(self, imgsz=(1, 3, 640, 640)): @staticmethod def _model_type(p="path/to/model.pt"): - # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx + """ + Determines model type from file path or URL, supporting various export formats. + + Example: path='path/to/model.onnx' -> type=onnx + """ # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle] from export import export_formats from utils.downloads import is_url @@ -650,7 +721,7 @@ def _model_type(p="path/to/model.pt"): @staticmethod def _load_metadata(f=Path("path/to/meta.yaml")): - # Load metadata from meta.yaml if it exists + """Loads metadata from a YAML file, returning strides and names if the file exists, otherwise `None`.""" if f.exists(): d = yaml_load(f) return d["stride"], d["names"] # assign stride, names @@ -668,6 +739,7 @@ class AutoShape(nn.Module): amp = False # Automatic Mixed Precision (AMP) inference def __init__(self, model, verbose=True): + """Initializes YOLOv5 model for inference, setting up attributes and preparing model for evaluation.""" super().__init__() if verbose: LOGGER.info("Adding AutoShape... ") @@ -681,7 +753,11 @@ def __init__(self, model, verbose=True): m.export = True # do not output loss values def _apply(self, fn): - # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers + """Applies to(), cpu(), cuda(), half() etc. to model tensors excluding parameters or registered buffers.""" self = super()._apply(fn) if self.pt: m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect() @@ -693,7 +769,12 @@ def _apply(self, fn): @smart_inference_mode() def forward(self, ims, size=640, augment=False, profile=False): - # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are: + """ + Performs inference on inputs with optional augment & profiling. + + Supports various formats including file, URI, OpenCV, PIL, numpy, torch.
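Note: the `forward` docstring above enumerates the accepted source types. A short usage sketch of that mixed-input behavior (hub download assumed):

```python
import numpy as np
import torch

model = torch.hub.load("ultralytics/yolov5", "yolov5s")  # AutoShape-wrapped model

ims = [
    "https://ultralytics.com/images/zidane.jpg",  # URI
    np.zeros((720, 1280, 3), dtype=np.uint8),  # HWC RGB numpy array
]
results = model(ims, size=640)  # letterboxing, batching, and NMS handled internally
results.print()
```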
+ """ + # For size(height=640, width=1280), RGB images example inputs are: # file: ims = 'data/images/zidane.jpg' # str or PosixPath # URI: = 'https://ultralytics.com/images/zidane.jpg' # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3) @@ -761,6 +842,7 @@ def forward(self, ims, size=640, augment=False, profile=False): class Detections: # YOLOv5 detections class for inference results def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None): + """Initializes the YOLOv5 Detections class with image info, predictions, filenames, timing and normalization.""" super().__init__() d = pred[0].device # device gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations @@ -778,6 +860,7 @@ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None): self.s = tuple(shape) # inference BCHW shape def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path("")): + """Executes model predictions, displaying and/or saving outputs with optional crops and labels.""" s, crops = "", [] for i, (im, pred) in enumerate(zip(self.ims, self.pred)): s += f"\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} " # string @@ -832,22 +915,42 @@ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, l @TryExcept("Showing images is not supported in this environment") def show(self, labels=True): + """ + Displays detection results with optional labels. + + Usage: show(labels=True) + """ self._run(show=True, labels=labels) # show results def save(self, labels=True, save_dir="runs/detect/exp", exist_ok=False): + """ + Saves detection results with optional labels to a specified directory. + + Usage: save(labels=True, save_dir='runs/detect/exp', exist_ok=False) + """ save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir self._run(save=True, labels=labels, save_dir=save_dir) # save results def crop(self, save=True, save_dir="runs/detect/exp", exist_ok=False): + """ + Crops detection results, optionally saves them to a directory. + + Args: save (bool), save_dir (str), exist_ok (bool). + """ save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None return self._run(crop=True, save=save, save_dir=save_dir) # crop results def render(self, labels=True): + """Renders detection results with optional labels on images; args: labels (bool) indicating label inclusion.""" self._run(render=True, labels=labels) # render results return self.ims def pandas(self): - # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0]) + """ + Returns detections as pandas DataFrames for various box formats (xyxy, xyxyn, xywh, xywhn). + + Example: print(results.pandas().xyxy[0]). + """ new = copy(self) # return copy ca = "xmin", "ymin", "xmax", "ymax", "confidence", "class", "name" # xyxy columns cb = "xcenter", "ycenter", "width", "height", "confidence", "class", "name" # xywh columns @@ -857,7 +960,11 @@ def pandas(self): return new def tolist(self): - # return a list of Detections objects, i.e. 'for result in results.tolist():' + """ + Converts a Detections object into a list of individual detection results for iteration. 
+ + Example: for result in results.tolist(): + """ r = range(self.n) # iterable return [ Detections( [self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s, ) for i in r ] def print(self): + """Logs the string representation of the current object's state via the LOGGER.""" LOGGER.info(self.__str__()) def __len__(self): # override len(results) @@ -881,6 +989,7 @@ def __str__(self): # override print(results) return self._run(pprint=True) # print results def __repr__(self): + """Returns a string representation of the YOLOv5 object, including its class and formatted results.""" return f"YOLOv5 {self.__class__} instance\n" + self.__str__() @@ -894,6 +1003,7 @@ def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of mas self.cv3 = Conv(c_, c2) def forward(self, x): + """Performs a forward pass using convolutional layers and upsampling on input tensor `x`.""" return self.cv3(self.cv2(self.upsample(self.cv1(x)))) @@ -910,6 +1020,7 @@ def __init__( self.linear = nn.Linear(c_, c2) # to x(b,c2) def forward(self, x): + """Processes input through conv, pool, drop, and linear layers; supports list concatenation input.""" if isinstance(x, list): x = torch.cat(x, 1) return self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) diff --git a/models/experimental.py b/models/experimental.py index c242364bdec5..ab229d50e30f 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -19,6 +19,7 @@ def __init__(self, n, weight=False): # n: number of inputs self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True) # layer weights def forward(self, x): + """Processes input through a customizable weighted sum of `n` inputs, optionally applying learned weights.""" y = x[0] # no weight if self.weight: w = torch.sigmoid(self.w) * 2 @@ -53,15 +54,21 @@ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kern self.act = nn.SiLU() def forward(self, x): + """Performs forward pass by applying SiLU activation on batch-normalized concatenated convolutional layer + outputs. + """ return self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) class Ensemble(nn.ModuleList): - # Ensemble of models + """Ensemble of models.""" + def __init__(self): + """Initializes an ensemble of models to be used for aggregated predictions.""" super().__init__() def forward(self, x, augment=False, profile=False, visualize=False): + """Performs forward pass aggregating outputs from an ensemble of models.""" y = [module(x, augment, profile, visualize)[0] for module in self] # y = torch.stack(y).max(0)[0] # max ensemble # y = torch.stack(y).mean(0) # mean ensemble @@ -70,7 +77,11 @@ def forward(self, x, augment=False, profile=False, visualize=False): def attempt_load(weights, device=None, inplace=True, fuse=True): - # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a + """ + Loads and fuses an ensemble or single YOLOv5 model from weights, handling device placement and model adjustments. + + Example inputs: weights=[a,b,c] or a single model weights=[a] or weights=a.
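Note: `pandas()` and `tolist()` documented above are the usual post-processing paths for a `Detections` object. A usage sketch (hub download assumed):

```python
import torch

model = torch.hub.load("ultralytics/yolov5", "yolov5s")
results = model("https://ultralytics.com/images/zidane.jpg")

df = results.pandas().xyxy[0]  # columns: xmin, ymin, xmax, ymax, confidence, class, name
print(df[df["confidence"] > 0.5])

for r in results.tolist():  # one Detections object per input image
    r.print()
```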
+ """ from models.yolo import Detect, Model model = Ensemble() diff --git a/models/tf.py b/models/tf.py index 53520b52c086..006a66d2b0f6 100644 --- a/models/tf.py +++ b/models/tf.py @@ -51,6 +51,7 @@ class TFBN(keras.layers.Layer): # TensorFlow BatchNormalization wrapper def __init__(self, w=None): + """Initializes a TensorFlow BatchNormalization layer with optional pretrained weights.""" super().__init__() self.bn = keras.layers.BatchNormalization( beta_initializer=keras.initializers.Constant(w.bias.numpy()), @@ -61,12 +62,19 @@ def __init__(self, w=None): ) def call(self, inputs): + """Applies batch normalization to the inputs.""" return self.bn(inputs) class TFPad(keras.layers.Layer): # Pad inputs in spatial dimensions 1 and 2 def __init__(self, pad): + """ + Initializes a padding layer for spatial dimensions 1 and 2 with specified padding, supporting both int and tuple + inputs. + + Inputs are + """ super().__init__() if isinstance(pad, int): self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]]) @@ -74,13 +82,19 @@ def __init__(self, pad): self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]]) def call(self, inputs): + """Pads input tensor with zeros using specified padding, suitable for int and tuple pad dimensions.""" return tf.pad(inputs, self.pad, mode="constant", constant_values=0) class TFConv(keras.layers.Layer): # Standard convolution def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None): - # ch_in, ch_out, weights, kernel, stride, padding, groups + """ + Initializes a standard convolution layer with optional batch normalization and activation; supports only + group=1. + + Inputs are ch_in, ch_out, weights, kernel, stride, padding, groups. + """ super().__init__() assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument" # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding) @@ -99,13 +113,19 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None): self.act = activations(w.act) if act else tf.identity def call(self, inputs): + """Applies convolution, batch normalization, and activation function to input tensors.""" return self.act(self.bn(self.conv(inputs))) class TFDWConv(keras.layers.Layer): # Depthwise convolution def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None): - # ch_in, ch_out, weights, kernel, stride, padding, groups + """ + Initializes a depthwise convolution layer with optional batch normalization and activation for TensorFlow + models. + + Input are ch_in, ch_out, weights, kernel, stride, padding, groups. + """ super().__init__() assert c2 % c1 == 0, f"TFDWConv() output={c2} must be a multiple of input={c1} channels" conv = keras.layers.DepthwiseConv2D( @@ -122,13 +142,18 @@ def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None): self.act = activations(w.act) if act else tf.identity def call(self, inputs): + """Applies convolution, batch normalization, and activation function to input tensors.""" return self.act(self.bn(self.conv(inputs))) class TFDWConvTranspose2d(keras.layers.Layer): # Depthwise ConvTranspose2d def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None): - # ch_in, ch_out, weights, kernel, stride, padding, groups + """ + Initializes depthwise ConvTranspose2D layer with specific channel, kernel, stride, and padding settings. + + Inputs are ch_in, ch_out, weights, kernel, stride, padding, groups. 
+ """ super().__init__() assert c1 == c2, f"TFDWConv() output={c2} must be equal to input={c1} channels" assert k == 4 and p1 == 1, "TFDWConv() only valid for k=4 and p1=1" @@ -149,13 +174,19 @@ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None): ] def call(self, inputs): + """Processes input through parallel convolutions and concatenates results, trimming border pixels.""" return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1] class TFFocus(keras.layers.Layer): # Focus wh information into c-space def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None): - # ch_in, ch_out, kernel, stride, padding, groups + """ + Initializes TFFocus layer to focus width and height information into channel space with custom convolution + parameters. + + Inputs are ch_in, ch_out, kernel, stride, padding, groups. + """ super().__init__() self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv) @@ -175,12 +206,16 @@ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, self.add = shortcut and c1 == c2 def call(self, inputs): + """Performs forward pass; if shortcut is True & input/output channels match, adds input to the convolution + result. + """ return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs)) class TFCrossConv(keras.layers.Layer): # Cross Convolution def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None): + """Initializes cross convolution layer with optional expansion, grouping, and shortcut addition capabilities.""" super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = TFConv(c1, c_, (1, k), (1, s), w=w.cv1) @@ -188,12 +223,16 @@ def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None): self.add = shortcut and c1 == c2 def call(self, inputs): + """Passes input through two convolutions optionally adding the input if channel dimensions match.""" return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs)) class TFConv2d(keras.layers.Layer): # Substitution for PyTorch nn.Conv2D def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None): + """Initializes a TensorFlow 2D convolution layer, mimicking PyTorch's nn.Conv2D functionality for given filter + sizes and stride. + """ super().__init__() assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument" self.conv = keras.layers.Conv2D( @@ -207,13 +246,19 @@ def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None): ) def call(self, inputs): + """Applies a convolution operation to the inputs and returns the result.""" return self.conv(inputs) class TFBottleneckCSP(keras.layers.Layer): # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None): - # ch_in, ch_out, number, shortcut, groups, expansion + """ + Initializes CSP bottleneck layer with specified channel sizes, count, shortcut option, groups, and expansion + ratio. + + Inputs are ch_in, ch_out, number, shortcut, groups, expansion. + """ super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1) @@ -225,6 +270,9 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None): self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)]) def call(self, inputs): + """Processes input through the model layers, concatenates, normalizes, activates, and reduces the output + dimensions. 
+ """ y1 = self.cv3(self.m(self.cv1(inputs))) y2 = self.cv2(inputs) return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3)))) @@ -233,7 +281,11 @@ def call(self, inputs): class TFC3(keras.layers.Layer): # CSP Bottleneck with 3 convolutions def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None): - # ch_in, ch_out, number, shortcut, groups, expansion + """ + Initializes CSP Bottleneck with 3 convolutions, supporting optional shortcuts and group convolutions. + + Inputs are ch_in, ch_out, number, shortcut, groups, expansion. + """ super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1) @@ -242,13 +294,22 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None): self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)]) def call(self, inputs): + """ + Processes input through a sequence of transformations for object detection (YOLOv5). + + See https://github.com/ultralytics/yolov5. + """ return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3)) class TFC3x(keras.layers.Layer): # 3 module with cross-convolutions def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None): - # ch_in, ch_out, number, shortcut, groups, expansion + """ + Initializes layer with cross-convolutions for enhanced feature extraction in object detection models. + + Inputs are ch_in, ch_out, number, shortcut, groups, expansion. + """ super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1) @@ -259,12 +320,14 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None): ) def call(self, inputs): + """Processes input through cascaded convolutions and merges features, returning the final tensor output.""" return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3)) class TFSPP(keras.layers.Layer): # Spatial pyramid pooling layer used in YOLOv3-SPP def __init__(self, c1, c2, k=(5, 9, 13), w=None): + """Initializes a YOLOv3-SPP layer with specific input/output channels and kernel sizes for pooling.""" super().__init__() c_ = c1 // 2 # hidden channels self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1) @@ -272,6 +335,7 @@ def __init__(self, c1, c2, k=(5, 9, 13), w=None): self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding="SAME") for x in k] def call(self, inputs): + """Processes input through two TFConv layers and concatenates with max-pooled outputs at intermediate stage.""" x = self.cv1(inputs) return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3)) @@ -279,6 +343,9 @@ def call(self, inputs): class TFSPPF(keras.layers.Layer): # Spatial pyramid pooling-Fast layer def __init__(self, c1, c2, k=5, w=None): + """Initializes a fast spatial pyramid pooling layer with customizable in/out channels, kernel size, and + weights. + """ super().__init__() c_ = c1 // 2 # hidden channels self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1) @@ -286,6 +353,9 @@ def __init__(self, c1, c2, k=5, w=None): self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding="SAME") def call(self, inputs): + """Executes the model's forward pass, concatenating input features with three max-pooled versions before final + convolution. 
+ """ x = self.cv1(inputs) y1 = self.m(x) y2 = self.m(y1) @@ -312,6 +382,7 @@ def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detec self.grid[i] = self._make_grid(nx, ny) def call(self, inputs): + """Performs forward pass through the model layers to predict object bounding boxes and classifications.""" z = [] # inference output x = [] for i in range(self.nl): @@ -336,7 +407,7 @@ def call(self, inputs): @staticmethod def _make_grid(nx=20, ny=20): - # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) + """Generates a 2D grid of coordinates in (x, y) format with shape [1, 1, ny*nx, 2].""" # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float() xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny)) return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32) @@ -345,6 +416,9 @@ def _make_grid(nx=20, ny=20): class TFSegment(TFDetect): # YOLOv5 Segment head for segmentation models def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w=None): + """Initializes YOLOv5 Segment head with specified channel depths, anchors, and input size for segmentation + models. + """ super().__init__(nc, anchors, ch, imgsz, w) self.nm = nm # number of masks self.npr = npr # number of protos @@ -354,6 +428,7 @@ def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w self.detect = TFDetect.call def call(self, x): + """Applies detection and proto layers on input, returning detections and optionally protos if training.""" p = self.proto(x[0]) # p = TFUpsample(None, scale_factor=4, mode='nearest')(self.proto(x[0])) # (optional) full-size protos p = tf.transpose(p, [0, 3, 1, 2]) # from shape(1,160,160,32) to shape(1,32,160,160) @@ -363,6 +438,9 @@ def call(self, x): class TFProto(keras.layers.Layer): def __init__(self, c1, c_=256, c2=32, w=None): + """Initializes TFProto layer with convolutional and upsampling layers for feature extraction and + transformation. + """ super().__init__() self.cv1 = TFConv(c1, c_, k=3, w=w.cv1) self.upsample = TFUpsample(None, scale_factor=2, mode="nearest") @@ -370,6 +448,7 @@ def __init__(self, c1, c_=256, c2=32, w=None): self.cv3 = TFConv(c_, c2, w=w.cv3) def call(self, inputs): + """Performs forward pass through the model, applying convolutions and upscaling on input tensor.""" return self.cv3(self.cv2(self.upsample(self.cv1(inputs)))) @@ -385,17 +464,20 @@ def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments # size=(x.shape[1] * 2, x.shape[2] * 2)) def call(self, inputs): + """Applies upsample operation to inputs using nearest neighbor interpolation.""" return self.upsample(inputs) class TFConcat(keras.layers.Layer): # TF version of torch.concat() def __init__(self, dimension=1, w=None): + """Initializes a TensorFlow layer for NCHW to NHWC concatenation, requiring dimension=1.""" super().__init__() assert dimension == 1, "convert only NCHW to NHWC concat" self.d = 3 def call(self, inputs): + """Concatenates a list of tensors along the last dimension, used for NCHW to NHWC conversion.""" return tf.concat(inputs, self.d) @@ -539,7 +621,9 @@ def predict( @staticmethod def _xywh2xyxy(xywh): - # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + """Converts bounding box format from [x, y, w, h] to [x1, y1, x2, y2], where xy1=top-left and xy2=bottom- + right. 
+ """ x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1) return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1) @@ -547,7 +631,7 @@ def _xywh2xyxy(xywh): class AgnosticNMS(keras.layers.Layer): # TF Agnostic NMS def call(self, input, topk_all, iou_thres, conf_thres): - # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450 + """Performs agnostic NMS on input tensors using given thresholds and top-K selection.""" return tf.map_fn( lambda x: self._nms(x, topk_all, iou_thres, conf_thres), input, @@ -589,7 +673,7 @@ def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS def activations(act=nn.SiLU): - # Returns TF activation from input PyTorch activation + """Converts PyTorch activations to TensorFlow equivalents, supporting LeakyReLU, Hardswish, and SiLU/Swish.""" if isinstance(act, nn.LeakyReLU): return lambda x: keras.activations.relu(x, alpha=0.1) elif isinstance(act, nn.Hardswish): @@ -601,7 +685,9 @@ def activations(act=nn.SiLU): def representative_dataset_gen(dataset, ncalib=100): - # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays + """Generates a representative dataset for calibration by yielding transformed numpy arrays from the input + dataset. + """ for n, (path, img, im0s, vid_cap, string) in enumerate(dataset): im = np.transpose(img, [1, 2, 0]) im = np.expand_dims(im, axis=0).astype(np.float32) @@ -637,6 +723,9 @@ def run( def parse_opt(): + """Parses and returns command-line options for model inference, including weights path, image size, batch size, and + dynamic batching. + """ parser = argparse.ArgumentParser() parser.add_argument("--weights", type=str, default=ROOT / "yolov5s.pt", help="weights path") parser.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int, default=[640], help="inference size h,w") @@ -649,6 +738,7 @@ def parse_opt(): def main(opt): + """Executes the YOLOv5 model run function with parsed command line options.""" run(**vars(opt)) diff --git a/models/yolo.py b/models/yolo.py index e98351b98691..ef6c1015f41e 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -88,6 +88,7 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer self.inplace = inplace # use inplace ops (e.g. 
slice assignment) def forward(self, x): + """Processes input through YOLOv5 layers, altering shape for detection: `x(bs, 3, ny, nx, 85)`.""" z = [] # inference output for i in range(self.nl): x[i] = self.m[i](x[i]) # conv @@ -113,6 +114,7 @@ def forward(self, x): return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x) def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, "1.10.0")): + """Generates a mesh grid for anchor boxes with optional compatibility for torch versions < 1.10.""" d = self.anchors[i].device t = self.anchors[i].dtype shape = 1, self.na, ny, nx, 2 # grid shape @@ -126,6 +128,7 @@ def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version class Segment(Detect): # YOLOv5 Segment head for segmentation models def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True): + """Initializes YOLOv5 Segment head with options for mask count, protos, and channel adjustments.""" super().__init__(nc, anchors, ch, inplace) self.nm = nm # number of masks self.npr = npr # number of protos @@ -135,17 +138,25 @@ def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True): self.detect = Detect.forward def forward(self, x): + """Processes input through the network, returning detections and prototypes; adjusts output based on + training/export mode. + """ p = self.proto(x[0]) x = self.detect(self, x) return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1]) class BaseModel(nn.Module): - # YOLOv5 base model + """YOLOv5 base model.""" + def forward(self, x, profile=False, visualize=False): + """Executes a single-scale inference or training pass on the YOLOv5 base model, with options for profiling and + visualization. + """ return self._forward_once(x, profile, visualize) # single-scale inference, train def _forward_once(self, x, profile=False, visualize=False): + """Performs a forward pass on the YOLOv5 model, enabling profiling and feature visualization options.""" y, dt = [], [] # outputs for m in self.model: if m.f != -1: # if not from previous layer @@ -159,6 +170,7 @@ def _forward_once(self, x, profile=False, visualize=False): return x def _profile_one_layer(self, m, x, dt): + """Profiles a single layer's performance by computing GFLOPs, execution time, and parameters.""" c = m == self.model[-1] # is final layer, copy input as inplace fix o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1e9 * 2 if thop else 0 # FLOPs t = time_sync() @@ -185,7 +197,9 @@ def info(self, verbose=False, img_size=640): # print model information model_info(self, verbose, img_size) def _apply(self, fn): - # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers + """Applies transformations like to(), cpu(), cuda(), half() to model tensors excluding parameters or registered + buffers. 
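Note: the `_make_grid` helper above builds per-cell (x, y) offsets that are broadcast against anchor predictions. A standalone sketch of the grid construction (assumes torch>=1.10 for `indexing="ij"`, matching the version check in the code):

```python
import torch

na, ny, nx = 3, 20, 20  # anchors per layer, grid height, grid width

yv, xv = torch.meshgrid(torch.arange(ny), torch.arange(nx), indexing="ij")
grid = torch.stack((xv, yv), 2).expand(1, na, ny, nx, 2).float()  # (x, y) offsets per cell
print(grid.shape)  # torch.Size([1, 3, 20, 20, 2])
```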
+ """ self = super()._apply(fn) m = self.model[-1] # Detect() if isinstance(m, (Detect, Segment)): @@ -239,11 +253,13 @@ def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None, anchors=None): # model, i LOGGER.info("") def forward(self, x, augment=False, profile=False, visualize=False): + """Performs single-scale or augmented inference and may include profiling or visualization.""" if augment: return self._forward_augment(x) # augmented inference, None return self._forward_once(x, profile, visualize) # single-scale inference, train def _forward_augment(self, x): + """Performs augmented inference across different scales and flips, returning combined detections.""" img_size = x.shape[-2:] # height, width s = [1, 0.83, 0.67] # scales f = [None, 3, None] # flips (2-ud, 3-lr) @@ -258,7 +274,7 @@ def _forward_augment(self, x): return torch.cat(y, 1), None # augmented inference, train def _descale_pred(self, p, flips, scale, img_size): - # de-scale predictions following augmented inference (inverse operation) + """De-scales predictions from augmented inference, adjusting for flips and image size.""" if self.inplace: p[..., :4] /= scale # de-scale if flips == 2: @@ -275,7 +291,9 @@ def _descale_pred(self, p, flips, scale, img_size): return p def _clip_augmented(self, y): - # Clip YOLOv5 augmented inference tails + """Clips augmented inference tails for YOLOv5 models, affecting first and last tensors based on grid points and + layer counts. + """ nl = self.model[-1].nl # number of detection layers (P3-P5) g = sum(4**x for x in range(nl)) # grid points e = 1 # exclude layer count @@ -304,6 +322,7 @@ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class SegmentationModel(DetectionModel): # YOLOv5 segmentation model def __init__(self, cfg="yolov5s-seg.yaml", ch=3, nc=None, anchors=None): + """Initializes a YOLOv5 segmentation model with configurable params: cfg (str) for configuration, ch (int) for channels, nc (int) for num classes, anchors (list).""" super().__init__(cfg, ch, nc, anchors) @@ -314,7 +333,9 @@ def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, nu self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg) def _from_detection_model(self, model, nc=1000, cutoff=10): - # Create a YOLOv5 classification model from a YOLOv5 detection model + """Creates a classification model from a YOLOv5 detection model, slicing at `cutoff` and adding a classification + layer. + """ if isinstance(model, DetectMultiBackend): model = model.model # unwrap DetectMultiBackend model.model = model.model[:cutoff] # backbone @@ -329,7 +350,7 @@ def _from_detection_model(self, model, nc=1000, cutoff=10): self.nc = nc def _from_yaml(self, cfg): - # Create a YOLOv5 classification model from a *.yaml file + """Creates a YOLOv5 classification model from a specified *.yaml configuration file.""" self.model = None diff --git a/segment/predict.py b/segment/predict.py index 23a4e3538509..bea9bfe2f21c 100644 --- a/segment/predict.py +++ b/segment/predict.py @@ -257,6 +257,9 @@ def run( def parse_opt(): + """Parses command-line options for YOLOv5 inference including model paths, data sources, inference settings, and + output preferences. 
+    """
     parser = argparse.ArgumentParser()
     parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s-seg.pt", help="model path(s)")
     parser.add_argument("--source", type=str, default=ROOT / "data/images", help="file/dir/URL/glob/screen/0(webcam)")
@@ -293,6 +296,7 @@ def parse_opt():
 
 
 def main(opt):
+    """Executes YOLOv5 model inference with given options, checking for requirements before launching."""
     check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop"))
     run(**vars(opt))
 
diff --git a/segment/train.py b/segment/train.py
index fe262348fae4..ce59df9c635b 100644
--- a/segment/train.py
+++ b/segment/train.py
@@ -532,6 +532,11 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
 
 
 def parse_opt(known=False):
+    """
+    Parses command line arguments for training configurations, returning parsed arguments.
+
+    Supports both known and unknown args.
+    """
     parser = argparse.ArgumentParser()
     parser.add_argument("--weights", type=str, default=ROOT / "yolov5s-seg.pt", help="initial weights path")
     parser.add_argument("--cfg", type=str, default="", help="model.yaml path")
@@ -576,7 +581,7 @@ def parse_opt(known=False):
 
 
 def main(opt, callbacks=Callbacks()):
-    # Checks
+    """Initializes training or evolution of YOLOv5 models based on provided configuration and options."""
     if RANK in {-1, 0}:
         print_args(vars(opt))
         check_git_status()
@@ -733,7 +738,11 @@ def main(opt, callbacks=Callbacks()):
 
 
 def run(**kwargs):
-    # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
+    """
+    Executes YOLOv5 training with given parameters, altering options programmatically; returns updated options.
+
+    Example: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
+    """
     opt = parse_opt(True)
     for k, v in kwargs.items():
         setattr(opt, k, v)
diff --git a/segment/val.py b/segment/val.py
index 1e5159c710ed..bafdb5dcec07 100644
--- a/segment/val.py
+++ b/segment/val.py
@@ -71,7 +71,9 @@
 
 
 def save_one_txt(predn, save_conf, shape, file):
-    # Save one txt result
+    """Saves detection results in txt format; includes class, xywh (normalized), optionally confidence if `save_conf` is
+    True.
+    """
     gn = torch.tensor(shape)[[1, 0, 1, 0]]  # normalization gain whwh
     for *xyxy, conf, cls in predn.tolist():
         xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
@@ -81,7 +83,11 @@ def save_one_txt(predn, save_conf, shape, file):
 
 
 def save_one_json(predn, jdict, path, class_map, pred_masks):
-    # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
+    """
+    Saves a JSON file with detection results including bounding boxes, category IDs, scores, and segmentation masks.
+
+    Example JSON result: {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}.
+    """
     from pycocotools.mask import encode
 
     def single_encode(x):
@@ -437,6 +443,9 @@ def run(
 
 
 def parse_opt():
+    """Parses command line arguments for configuring YOLOv5 options like dataset path, weights, batch size, and
+    inference settings.
+ """ parser = argparse.ArgumentParser() parser.add_argument("--data", type=str, default=ROOT / "data/coco128-seg.yaml", help="dataset.yaml path") parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s-seg.pt", help="model path(s)") @@ -469,6 +478,7 @@ def parse_opt(): def main(opt): + """Executes YOLOv5 tasks including training, validation, testing, speed, and study with configurable options.""" check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop")) if opt.task in ("train", "val", "test"): # run normally diff --git a/train.py b/train.py index ca284d06df25..3f2f64385c90 100644 --- a/train.py +++ b/train.py @@ -505,6 +505,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio def parse_opt(known=False): + """Parses command-line arguments for YOLOv5 training, validation, and testing.""" parser = argparse.ArgumentParser() parser.add_argument("--weights", type=str, default=ROOT / "yolov5s.pt", help="initial weights path") parser.add_argument("--cfg", type=str, default="", help="model.yaml path") @@ -559,7 +560,7 @@ def parse_opt(known=False): def main(opt, callbacks=Callbacks()): - # Checks + """Runs training or hyperparameter evolution with specified options and optional callbacks.""" if RANK in {-1, 0}: print_args(vars(opt)) check_git_status() @@ -815,6 +816,7 @@ def main(opt, callbacks=Callbacks()): def generate_individual(input_ranges, individual_length): + """Generates a list of random values within specified input ranges for each gene in the individual.""" individual = [] for i in range(individual_length): lower_bound, upper_bound = input_ranges[i] @@ -823,7 +825,11 @@ def generate_individual(input_ranges, individual_length): def run(**kwargs): - # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt') + """ + Executes YOLOv5 training with given options, overriding with any kwargs provided. + + Example: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt') + """ opt = parse_opt(True) for k, v in kwargs.items(): setattr(opt, k, v) diff --git a/utils/__init__.py b/utils/__init__.py index eff756e2b90e..0b7e1fdfc31a 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,85 +1,95 @@ -# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license -"""utils/initialization.""" - -import contextlib -import platform -import threading - - -def emojis(str=""): - # Return platform-dependent emoji-safe version of string - return str.encode().decode("ascii", "ignore") if platform.system() == "Windows" else str - - -class TryExcept(contextlib.ContextDecorator): - # YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager - def __init__(self, msg=""): - self.msg = msg - - def __enter__(self): - pass - - def __exit__(self, exc_type, value, traceback): - if value: - print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}")) - return True - - -def threaded(func): - # Multi-threads a target function and returns thread. Usage: @threaded decorator - def wrapper(*args, **kwargs): - thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) - thread.start() - return thread - - return wrapper - - -def join_threads(verbose=False): - # Join all daemon threads, i.e. 
atexit.register(lambda: join_threads()) - main_thread = threading.current_thread() - for t in threading.enumerate(): - if t is not main_thread: - if verbose: - print(f"Joining thread {t.name}") - t.join() - - -def notebook_init(verbose=True): - # Check system software and hardware - print("Checking setup...") - - import os - import shutil - - from ultralytics.utils.checks import check_requirements - - from utils.general import check_font, is_colab - from utils.torch_utils import select_device # imports - - check_font() - - import psutil - - if check_requirements("wandb", install=False): - os.system("pip uninstall -y wandb") # eliminate unexpected account creation prompt with infinite hang - if is_colab(): - shutil.rmtree("/content/sample_data", ignore_errors=True) # remove colab /sample_data directory - - # System info - display = None - if verbose: - gb = 1 << 30 # bytes to GiB (1024 ** 3) - ram = psutil.virtual_memory().total - total, used, free = shutil.disk_usage("/") - with contextlib.suppress(Exception): # clear display if ipython is installed - from IPython import display - - display.clear_output() - s = f"({os.cpu_count()} CPUs, {ram / gb:.1f} GB RAM, {(total - free) / gb:.1f}/{total / gb:.1f} GB disk)" - else: - s = "" - - select_device(newline=False) - print(emojis(f"Setup complete ✅ {s}")) - return display +# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license +"""utils/initialization.""" + +import contextlib +import platform +import threading + + +def emojis(str=""): + """Returns an emoji-safe version of a string, stripped of emojis on Windows platforms.""" + return str.encode().decode("ascii", "ignore") if platform.system() == "Windows" else str + + +class TryExcept(contextlib.ContextDecorator): + # YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager + def __init__(self, msg=""): + """Initializes TryExcept with an optional message, used as a decorator or context manager for error handling.""" + self.msg = msg + + def __enter__(self): + """Enter the runtime context related to this object for error handling with an optional message.""" + pass + + def __exit__(self, exc_type, value, traceback): + """Context manager exit method that prints an error message with emojis if an exception occurred, always returns + True. + """ + if value: + print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}")) + return True + + +def threaded(func): + """Decorator @threaded to run a function in a separate thread, returning the thread instance.""" + + def wrapper(*args, **kwargs): + thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) + thread.start() + return thread + + return wrapper + + +def join_threads(verbose=False): + """ + Joins all daemon threads, optionally printing their names if verbose is True. 
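A minimal usage sketch of the two helpers above working together (the worker function here is hypothetical):

import time

@threaded
def save_async(path):
    """Hypothetical worker: simulates slow I/O in a daemon thread."""
    time.sleep(0.1)
    print(f"saved {path}")

save_async("results.csv")  # returns the Thread handle immediately instead of blocking
join_threads()  # later, e.g. at exit: block until every daemon thread has finished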
+ + Example: atexit.register(lambda: join_threads()) + """ + main_thread = threading.current_thread() + for t in threading.enumerate(): + if t is not main_thread: + if verbose: + print(f"Joining thread {t.name}") + t.join() + + +def notebook_init(verbose=True): + """Initializes notebook environment by checking requirements, cleaning up, and displaying system info.""" + print("Checking setup...") + + import os + import shutil + + from ultralytics.utils.checks import check_requirements + + from utils.general import check_font, is_colab + from utils.torch_utils import select_device # imports + + check_font() + + import psutil + + if check_requirements("wandb", install=False): + os.system("pip uninstall -y wandb") # eliminate unexpected account creation prompt with infinite hang + if is_colab(): + shutil.rmtree("/content/sample_data", ignore_errors=True) # remove colab /sample_data directory + + # System info + display = None + if verbose: + gb = 1 << 30 # bytes to GiB (1024 ** 3) + ram = psutil.virtual_memory().total + total, used, free = shutil.disk_usage("/") + with contextlib.suppress(Exception): # clear display if ipython is installed + from IPython import display + + display.clear_output() + s = f"({os.cpu_count()} CPUs, {ram / gb:.1f} GB RAM, {(total - free) / gb:.1f}/{total / gb:.1f} GB disk)" + else: + s = "" + + select_device(newline=False) + print(emojis(f"Setup complete ✅ {s}")) + return display diff --git a/utils/activations.py b/utils/activations.py index 616002f06a73..6218eb58440a 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -7,43 +7,54 @@ class SiLU(nn.Module): - # SiLU activation https://arxiv.org/pdf/1606.08415.pdf @staticmethod def forward(x): + """ + Applies the Sigmoid-weighted Linear Unit (SiLU) activation function. + + https://arxiv.org/pdf/1606.08415.pdf. + """ return x * torch.sigmoid(x) class Hardswish(nn.Module): - # Hard-SiLU activation @staticmethod def forward(x): - # return x * F.hardsigmoid(x) # for TorchScript and CoreML + """ + Applies the Hardswish activation function, compatible with TorchScript, CoreML, and ONNX. + + Equivalent to x * F.hardsigmoid(x) + """ return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0 # for TorchScript, CoreML and ONNX class Mish(nn.Module): - # Mish activation https://github.com/digantamisra98/Mish + """Mish activation https://github.com/digantamisra98/Mish.""" + @staticmethod def forward(x): + """Applies the Mish activation function, a smooth alternative to ReLU.""" return x * F.softplus(x).tanh() class MemoryEfficientMish(nn.Module): - # Mish activation memory-efficient class F(torch.autograd.Function): @staticmethod def forward(ctx, x): + """Applies the Mish activation function, a smooth ReLU alternative, to the input tensor `x`.""" ctx.save_for_backward(x) return x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x))) @staticmethod def backward(ctx, grad_output): + """Computes the gradient of the Mish activation function with respect to input `x`.""" x = ctx.saved_tensors[0] sx = torch.sigmoid(x) fx = F.softplus(x).tanh() return grad_output * (fx + x * sx * (1 - fx * fx)) def forward(self, x): + """Applies the Mish activation function to the input tensor `x`.""" return self.F.apply(x) @@ -55,30 +66,41 @@ def __init__(self, c1, k=3): # ch_in, kernel self.bn = nn.BatchNorm2d(c1) def forward(self, x): + """ + Applies FReLU activation with max operation between input and BN-convolved input. 
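A quick shape check for the activation above (sizes are illustrative; assumes the `FReLU` class defined here is in scope):

import torch

act = FReLU(c1=64)  # funnel activation for 64-channel feature maps
y = act(torch.randn(2, 64, 32, 32))  # torch.max(x, BN(conv(x))) preserves the input shape
print(y.shape)  # torch.Size([2, 64, 32, 32])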
+
+        https://arxiv.org/abs/2007.11824
+        """
         return torch.max(x, self.bn(self.conv(x)))
 
 
 class AconC(nn.Module):
-    r"""ACON activation (activate or not)
+    """
+    ACON activation (activate or not) function.
+
     AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
-    according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
+    See "Activate or Not: Learning Customized Activation" https://arxiv.org/pdf/2009.04759.pdf.
     """
 
     def __init__(self, c1):
+        """Initializes AconC with learnable parameters p1, p2, and beta for channel-wise activation control."""
         super().__init__()
         self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
         self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
         self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))
 
     def forward(self, x):
+        """Applies AconC activation function with learnable parameters for channel-wise control on input tensor x."""
         dpx = (self.p1 - self.p2) * x
         return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x
 
 
 class MetaAconC(nn.Module):
-    r"""ACON activation (activate or not)
-    MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
-    according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
+    """
+    ACON activation (activate or not) function.
+
+    MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
+    See "Activate or Not: Learning Customized Activation" https://arxiv.org/pdf/2009.04759.pdf.
     """
 
     def __init__(self, c1, k=1, s=1, r=16):  # ch_in, kernel, stride, r
@@ -92,6 +114,7 @@ def __init__(self, c1, k=1, s=1, r=16):  # ch_in, kernel, stride, r
         # self.bn2 = nn.BatchNorm2d(c1)
 
     def forward(self, x):
+        """Applies a forward pass transforming input `x` using learnable parameters and sigmoid activation."""
         y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
         # batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891
         # beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))  # bug/unstable
diff --git a/utils/augmentations.py b/utils/augmentations.py
index b3b9524320d0..500e13248a06 100644
--- a/utils/augmentations.py
+++ b/utils/augmentations.py
@@ -20,6 +20,7 @@
 class Albumentations:
     # YOLOv5 Albumentations class (optional, only used if package is installed)
     def __init__(self, size=640):
+        """Initializes Albumentations class for optional data augmentation in YOLOv5 with specified input size."""
         self.transform = None
         prefix = colorstr("albumentations: ")
         try:
@@ -46,6 +47,7 @@ def __init__(self, size=640):
             LOGGER.info(f"{prefix}{e}")
 
     def __call__(self, im, labels, p=1.0):
+        """Applies transformations to an image and labels with probability `p`, returning updated image and labels."""
         if self.transform and random.random() < p:
             new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0])  # transformed
             im, labels = new["image"], np.array([[c, *b] for c, b in zip(new["class_labels"], new["bboxes"])])
@@ -53,19 +55,23 @@ def __call__(self, im, labels, p=1.0):
 
 
 def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
-    # Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = (x - mean) / std
+    """
+    Applies ImageNet normalization to RGB images in BCHW format, modifying them in-place if specified.
+
+    Example: y = (x - mean) / std
+    """
     return TF.normalize(x, mean, std, inplace=inplace)
 
 
 def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
-    # Denormalize RGB images x per ImageNet stats in BCHW format, i.e.
= x * std + mean + """Reverses ImageNet normalization for BCHW format RGB images by applying `x = x * std + mean`.""" for i in range(3): x[:, i] = x[:, i] * std[i] + mean[i] return x def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5): - # HSV color-space augmentation + """Applies HSV color-space augmentation to an image with random gains for hue, saturation, and value.""" if hgain or sgain or vgain: r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV)) @@ -81,7 +87,7 @@ def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5): def hist_equalize(im, clahe=True, bgr=False): - # Equalize histogram on BGR image 'im' with im.shape(n,m,3) and range 0-255 + """Equalizes image histogram, with optional CLAHE, for BGR or RGB image with shape (n,m,3) and range 0-255.""" yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV) if clahe: c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) @@ -92,7 +98,11 @@ def hist_equalize(im, clahe=True, bgr=False): def replicate(im, labels): - # Replicate labels + """ + Replicates half of the smallest object labels in an image for data augmentation. + + Returns augmented image and labels. + """ h, w = im.shape[:2] boxes = labels[:, 1:].astype(int) x1, y1, x2, y2 = boxes.T @@ -109,7 +119,7 @@ def replicate(im, labels): def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): - # Resize and pad image while meeting stride-multiple constraints + """Resizes and pads image to new_shape with stride-multiple constraints, returns resized image, ratio, padding.""" shape = im.shape[:2] # current shape [height, width] if isinstance(new_shape, int): new_shape = (new_shape, new_shape) @@ -232,7 +242,11 @@ def random_perspective( def copy_paste(im, labels, segments, p=0.5): - # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy) + """ + Applies Copy-Paste augmentation by flipping and merging segments and labels on an image. + + Details at https://arxiv.org/abs/2012.07177. + """ n = len(segments) if p and n: h, w, c = im.shape # height, width, channels @@ -254,7 +268,11 @@ def copy_paste(im, labels, segments, p=0.5): def cutout(im, labels, p=0.5): - # Applies image cutout augmentation https://arxiv.org/abs/1708.04552 + """ + Applies cutout augmentation to an image with optional label adjustment, using random masks of varying sizes. + + Details at https://arxiv.org/abs/1708.04552. + """ if random.random() < p: h, w = im.shape[:2] scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction @@ -281,7 +299,11 @@ def cutout(im, labels, p=0.5): def mixup(im, labels, im2, labels2): - # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf + """ + Applies MixUp augmentation by blending images and labels. + + See https://arxiv.org/pdf/1710.09412.pdf for details. 
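The blend is a single convex combination; a standalone sketch of the same arithmetic on synthetic data:

import numpy as np

rng = np.random.default_rng(0)
im = rng.integers(0, 255, (640, 640, 3)).astype(np.uint8)
im2 = rng.integers(0, 255, (640, 640, 3)).astype(np.uint8)
labels, labels2 = np.zeros((2, 5)), np.zeros((3, 5))  # nx5 arrays: cls, xyxy

r = np.random.beta(32.0, 32.0)  # alpha=beta=32 concentrates the mix ratio near 0.5
mixed = (im * r + im2 * (1 - r)).astype(np.uint8)  # pixel-wise blend of both images
merged = np.concatenate((labels, labels2), 0)  # labels from both images are kept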
+ """ r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0 im = (im * r + im2 * (1 - r)).astype(np.uint8) labels = np.concatenate((labels, labels2), 0) @@ -341,7 +363,7 @@ def classify_albumentations( def classify_transforms(size=224): - # Transforms to apply if albumentations not installed + """Applies a series of transformations including center crop, ToTensor, and normalization for classification.""" assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)" # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]) return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]) @@ -350,6 +372,9 @@ def classify_transforms(size=224): class LetterBox: # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]) def __init__(self, size=(640, 640), auto=False, stride=32): + """Initializes a LetterBox object for YOLOv5 image preprocessing with optional auto sizing and stride + adjustment. + """ super().__init__() self.h, self.w = (size, size) if isinstance(size, int) else size self.auto = auto # pass max size integer, automatically solve for short side using stride @@ -369,6 +394,7 @@ def __call__(self, im): # im = np.array HWC class CenterCrop: # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()]) def __init__(self, size=640): + """Initializes CenterCrop for image preprocessing, accepting single int or tuple for size, defaults to 640.""" super().__init__() self.h, self.w = (size, size) if isinstance(size, int) else size @@ -382,6 +408,7 @@ def __call__(self, im): # im = np.array HWC class ToTensor: # YOLOv5 ToTensor class for image preprocessing, i.e. 
T.Compose([LetterBox(size), ToTensor()]) def __init__(self, half=False): + """Initializes ToTensor for YOLOv5 image preprocessing, with optional half precision (half=True for FP16).""" super().__init__() self.half = half diff --git a/utils/autoanchor.py b/utils/autoanchor.py index 89e4d97fdcd5..62c39811657b 100644 --- a/utils/autoanchor.py +++ b/utils/autoanchor.py @@ -15,7 +15,7 @@ def check_anchor_order(m): - # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary + """Checks and corrects anchor order against stride in YOLOv5 Detect() module if necessary.""" a = m.anchors.prod(-1).mean(-1).view(-1) # mean anchor area per output layer da = a[-1] - a[0] # delta a ds = m.stride[-1] - m.stride[0] # delta s @@ -26,7 +26,7 @@ def check_anchor_order(m): @TryExcept(f"{PREFIX}ERROR") def check_anchors(dataset, model, thr=4.0, imgsz=640): - # Check anchor fit to data, recompute if necessary + """Evaluates anchor fit to dataset and adjusts if necessary, supporting customizable threshold and image size.""" m = model.module.model[-1] if hasattr(model, "module") else model.model[-1] # Detect() shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale diff --git a/utils/autobatch.py b/utils/autobatch.py index 396dbed1dda4..52a71f62c47c 100644 --- a/utils/autobatch.py +++ b/utils/autobatch.py @@ -11,13 +11,13 @@ def check_train_batch_size(model, imgsz=640, amp=True): - # Check YOLOv5 training batch size + """Checks and computes optimal training batch size for YOLOv5 model, given image size and AMP setting.""" with torch.cuda.amp.autocast(amp): return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size def autobatch(model, imgsz=640, fraction=0.8, batch_size=16): - # Automatically estimate best YOLOv5 batch size to use `fraction` of available CUDA memory + """Estimates optimal YOLOv5 batch size using `fraction` of CUDA memory.""" # Usage: # import torch # from utils.autobatch import autobatch diff --git a/utils/callbacks.py b/utils/callbacks.py index f658a72cce7c..3275789fa12e 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -8,7 +8,7 @@ class Callbacks: """Handles all registered callbacks for YOLOv5 Hooks.""" def __init__(self): - # Define the available callbacks + """Initializes a Callbacks object to manage registered YOLOv5 training event hooks.""" self._callbacks = { "on_pretrain_routine_start": [], "on_pretrain_routine_end": [], diff --git a/utils/dataloaders.py b/utils/dataloaders.py index c821e917ed38..3e636717fb84 100644 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -73,7 +73,7 @@ def get_hash(paths): - # Returns a single hash value of a list of paths (files or dirs) + """Generates a single SHA256 hash for a list of file or directory paths by combining their sizes and paths.""" size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes h = hashlib.sha256(str(size).encode()) # hash sizes h.update("".join(paths).encode()) # hash paths @@ -81,7 +81,7 @@ def get_hash(paths): def exif_size(img): - # Returns exif-corrected PIL size + """Returns corrected PIL image size (width, height) considering EXIF orientation.""" s = img.size # (width, height) with contextlib.suppress(Exception): rotation = dict(img._getexif().items())[orientation] @@ -118,7 +118,11 @@ def exif_transpose(image): def seed_worker(worker_id): - # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader + """ + 
Sets the seed for a dataloader worker to ensure reproducibility, based on PyTorch's randomness notes. + + See https://pytorch.org/docs/stable/notes/randomness.html#dataloader. + """ worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) @@ -128,7 +132,7 @@ def seed_worker(worker_id): # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py class SmartDistributedSampler(distributed.DistributedSampler): def __iter__(self): - # deterministically shuffle based on epoch and seed + """Yields indices for distributed data sampling, shuffled deterministically based on epoch and seed.""" g = torch.Generator() g.manual_seed(self.seed + self.epoch) @@ -218,14 +222,19 @@ class InfiniteDataLoader(dataloader.DataLoader): """ def __init__(self, *args, **kwargs): + """Initializes an InfiniteDataLoader that reuses workers with standard DataLoader syntax, augmenting with a + repeating sampler. + """ super().__init__(*args, **kwargs) object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler)) self.iterator = super().__iter__() def __len__(self): + """Returns the length of the batch sampler's sampler in the InfiniteDataLoader.""" return len(self.batch_sampler.sampler) def __iter__(self): + """Yields batches of data indefinitely in a loop by resetting the sampler when exhausted.""" for _ in range(len(self)): yield next(self.iterator) @@ -239,9 +248,11 @@ class _RepeatSampler: """ def __init__(self, sampler): + """Initializes a perpetual sampler wrapping a provided `Sampler` instance for endless data iteration.""" self.sampler = sampler def __iter__(self): + """Returns an infinite iterator over the dataset by repeatedly yielding from the given sampler.""" while True: yield from iter(self.sampler) @@ -249,7 +260,12 @@ def __iter__(self): class LoadScreenshots: # YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"` def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None): - # source = [screen_number left top width height] (pixels) + """ + Initializes a screenshot dataloader for YOLOv5 with specified source region, image size, stride, auto, and + transforms. + + Source = [screen_number left top width height] (pixels) + """ check_requirements("mss") import mss @@ -278,10 +294,13 @@ def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None): self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height} def __iter__(self): + """Iterates over itself, enabling use in loops and iterable contexts.""" return self def __next__(self): - # mss screen capture: get raw pixels from the screen as np array + """Captures and returns the next screen frame as a BGR numpy array, cropping to only the first three channels + from BGRA. + """ im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: " @@ -296,8 +315,10 @@ def __next__(self): class LoadImages: - # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4` + """YOLOv5 image/video dataloader, i.e. 
`python detect.py --source image.jpg/vid.mp4`""" + def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1): + """Initializes YOLOv5 loader for images/videos, supporting glob patterns, directories, and lists of paths.""" if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line path = Path(path).read_text().rsplit() files = [] @@ -335,10 +356,12 @@ def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vi ) def __iter__(self): + """Initializes iterator by resetting count and returns the iterator object itself.""" self.count = 0 return self def __next__(self): + """Advances to the next file in the dataset, raising StopIteration if at the end.""" if self.count == self.nf: raise StopIteration path = self.files[self.count] @@ -379,7 +402,9 @@ def __next__(self): return path, im, im0, self.cap, s def _new_video(self, path): - # Create a new video capture object + """Initializes a new video capture object with path, frame count adjusted by stride, and orientation + metadata. + """ self.frame = 0 self.cap = cv2.VideoCapture(path) self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride) @@ -387,7 +412,7 @@ def _new_video(self, path): # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493 def _cv2_rotate(self, im): - # Rotate a cv2 video manually + """Rotates a cv2 image based on its orientation; supports 0, 90, and 180 degrees rotations.""" if self.orientation == 0: return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE) elif self.orientation == 180: @@ -397,12 +422,16 @@ def _cv2_rotate(self, im): return im def __len__(self): + """Returns the number of files in the dataset.""" return self.nf # number of files class LoadStreams: # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams` def __init__(self, sources="file.streams", img_size=640, stride=32, auto=True, transforms=None, vid_stride=1): + """Initializes a stream loader for processing video streams with YOLOv5, supporting various sources including + YouTube. + """ torch.backends.cudnn.benchmark = True # faster for fixed-size inference self.mode = "stream" self.img_size = img_size @@ -448,7 +477,7 @@ def __init__(self, sources="file.streams", img_size=640, stride=32, auto=True, t LOGGER.warning("WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.") def update(self, i, cap, stream): - # Read stream `i` frames in daemon thread + """Reads frames from stream `i`, updating imgs array; handles stream reopening on signal loss.""" n, f = 0, self.frames[i] # frame number, frame array while cap.isOpened() and n < f: n += 1 @@ -464,10 +493,14 @@ def update(self, i, cap, stream): time.sleep(0.0) # wait time def __iter__(self): + """Resets and returns the iterator for iterating over video frames or images in a dataset.""" self.count = -1 return self def __next__(self): + """Iterates over video frames or images, halting on thread stop or 'q' key press, raising `StopIteration` when + done. 
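Both loaders follow the same iterator protocol; a minimal consumption sketch (the source path is illustrative, and the snippet assumes the YOLOv5 repo root is on sys.path):

from utils.dataloaders import LoadImages

dataset = LoadImages("data/images", img_size=640, stride=32, auto=True)
for path, im, im0, cap, s in dataset:
    print(path, im.shape, im0.shape)  # im: letterboxed CHW array, im0: original BGR frame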
+ """ self.count += 1 if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord("q"): # q to quit cv2.destroyAllWindows() @@ -484,11 +517,14 @@ def __next__(self): return self.sources, im, im0, None, "" def __len__(self): + """Returns the number of sources in the dataset, supporting up to 32 streams at 30 FPS over 30 years.""" return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years def img2label_paths(img_paths): - # Define label paths as a function of image paths + """Generates label file paths from corresponding image file paths by replacing `/images/` with `/labels/` and + extension with `.txt`. + """ sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths] @@ -657,7 +693,7 @@ def __init__( pbar.close() def check_cache_ram(self, safety_margin=0.1, prefix=""): - # Check image caching requirements vs available memory + """Checks if available RAM is sufficient for caching images, adjusting for a safety margin.""" b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes n = min(self.n, 30) # extrapolate from 30 random images for _ in range(n): @@ -676,7 +712,7 @@ def check_cache_ram(self, safety_margin=0.1, prefix=""): return cache def cache_labels(self, path=Path("./labels.cache"), prefix=""): - # Cache dataset labels, check images and read shapes + """Caches dataset labels, verifies images, reads shapes, and tracks dataset integrity.""" x = {} # dict nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages desc = f"{prefix}Scanning {path.parent / path.stem}..." @@ -716,6 +752,7 @@ def cache_labels(self, path=Path("./labels.cache"), prefix=""): return x def __len__(self): + """Returns the number of images in the dataset.""" return len(self.im_files) # def __iter__(self): @@ -725,6 +762,7 @@ def __len__(self): # return self def __getitem__(self, index): + """Fetches the dataset item at the given index, considering linear, shuffled, or weighted sampling.""" index = self.indices[index] # linear, shuffled, or image_weights hyp = self.hyp @@ -801,7 +839,11 @@ def __getitem__(self, index): return torch.from_numpy(img), labels_out, self.im_files[index], shapes def load_image(self, i): - # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw) + """ + Loads an image by index, returning the image, its original dimensions, and resized dimensions. + + Returns (im, original hw, resized hw) + """ im, f, fn = ( self.ims[i], self.im_files[i], @@ -822,13 +864,13 @@ def load_image(self, i): return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized def cache_images_to_disk(self, i): - # Saves an image as an *.npy file for faster loading + """Saves an image to disk as an *.npy file for quicker loading, identified by index `i`.""" f = self.npy_files[i] if not f.exists(): np.save(f.as_posix(), cv2.imread(self.im_files[i])) def load_mosaic(self, index): - # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic + """Loads a 4-image mosaic for YOLOv5, combining 1 selected and 3 random images, with labels and segments.""" labels4, segments4 = [], [] s = self.img_size yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y @@ -888,7 +930,9 @@ def load_mosaic(self, index): return img4, labels4 def load_mosaic9(self, index): - # YOLOv5 9-mosaic loader. 
Loads 1 image + 8 random images into a 9-image mosaic + """Loads 1 image + 8 random images into a 9-image mosaic for augmented YOLOv5 training, returning labels and + segments. + """ labels9, segments9 = [], [] s = self.img_size indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices @@ -968,6 +1012,7 @@ def load_mosaic9(self, index): @staticmethod def collate_fn(batch): + """Batches images, labels, paths, and shapes, assigning unique indices to targets in merged label tensor.""" im, label, path, shapes = zip(*batch) # transposed for i, lb in enumerate(label): lb[:, 0] = i # add target image index for build_targets() @@ -975,6 +1020,7 @@ def collate_fn(batch): @staticmethod def collate_fn4(batch): + """Bundles a batch's data by quartering the number of shapes and paths, preparing it for model input.""" im, label, path, shapes = zip(*batch) # transposed n = len(shapes) // 4 im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n] @@ -1003,7 +1049,9 @@ def collate_fn4(batch): # Ancillary functions -------------------------------------------------------------------------------------------------- def flatten_recursive(path=DATASETS_DIR / "coco128"): - # Flatten a recursive directory by bringing all files to top level + """Flattens a directory by copying all files from subdirectories to a new top-level directory, preserving + filenames. + """ new_path = Path(f"{str(path)}_flat") if os.path.exists(new_path): shutil.rmtree(new_path) # delete output folder @@ -1073,7 +1121,7 @@ def autosplit(path=DATASETS_DIR / "coco128/images", weights=(0.9, 0.1, 0.0), ann def verify_image_label(args): - # Verify one image-label pair + """Verifies a single image-label pair, ensuring image format, size, and legal label values.""" im_file, lb_file, prefix = args nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, "", [] # number (missing, found, empty, corrupt), message, segments try: @@ -1141,7 +1189,9 @@ class HUBDatasetStats: """ def __init__(self, path="coco128.yaml", autodownload=False): - # Initialize class + """Initializes HUBDatasetStats with optional auto-download for datasets, given a path to dataset YAML or ZIP + file. + """ zipped, data_dir, yaml_path = self._unzip(Path(path)) try: with open(check_yaml(yaml_path), errors="ignore") as f: @@ -1160,7 +1210,9 @@ def __init__(self, path="coco128.yaml", autodownload=False): @staticmethod def _find_yaml(dir): - # Return data.yaml file + """Finds and returns the path to a single '.yaml' file in the specified directory, preferring files that match + the directory name. 
+ """ files = list(dir.glob("*.yaml")) or list(dir.rglob("*.yaml")) # try root level first and then recursive assert files, f"No *.yaml file found in {dir}" if len(files) > 1: @@ -1170,7 +1222,7 @@ def _find_yaml(dir): return files[0] def _unzip(self, path): - # Unzip data.zip + """Unzips a .zip file at 'path', returning success status, unzipped directory, and path to YAML file within.""" if not str(path).endswith(".zip"): # path is data.yaml return False, None, path assert Path(path).is_file(), f"Error unzipping {path}, file not found" @@ -1180,7 +1232,7 @@ def _unzip(self, path): return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path def _hub_ops(self, f, max_dim=1920): - # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing + """Resizes and saves an image at reduced quality for web/app viewing, supporting both PIL and OpenCV.""" f_new = self.im_dir / Path(f).name # dataset-hub image filename try: # use PIL im = Image.open(f) @@ -1198,7 +1250,8 @@ def _hub_ops(self, f, max_dim=1920): cv2.imwrite(str(f_new), im) def get_json(self, save=False, verbose=False): - # Return dataset JSON for Ultralytics HUB + """Generates dataset JSON for Ultralytics HUB, optionally saves or prints it; save=bool, verbose=bool.""" + def _round(labels): # Update labels to integer class and 6 decimal place floats return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels] @@ -1235,7 +1288,9 @@ def _round(labels): return self.stats def process_images(self): - # Compress images for Ultralytics HUB + """Compresses images for Ultralytics HUB across 'train', 'val', 'test' splits and saves to specified + directory. + """ for split in "train", "val", "test": if self.data.get(split) is None: continue @@ -1259,6 +1314,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): """ def __init__(self, root, augment, imgsz, cache=False): + """Initializes YOLOv5 Classification Dataset with optional caching, augmentations, and transforms for image + classification. + """ super().__init__(root=root) self.torch_transforms = classify_transforms(imgsz) self.album_transforms = classify_albumentations(augment, imgsz) if augment else None @@ -1267,6 +1325,7 @@ def __init__(self, root, augment, imgsz, cache=False): self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im def __getitem__(self, i): + """Fetches and transforms an image sample by index, supporting RAM/disk caching and Augmentations.""" f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image if self.cache_ram and im is None: im = self.samples[i][3] = cv2.imread(f) diff --git a/utils/downloads.py b/utils/downloads.py index ccab278a1d00..071e1b077bf6 100644 --- a/utils/downloads.py +++ b/utils/downloads.py @@ -11,7 +11,7 @@ def is_url(url, check=True): - # Check if string is URL and check if URL exists + """Determines if a string is a URL and optionally checks its existence online, returning a boolean.""" try: url = str(url) result = urllib.parse.urlparse(url) @@ -22,13 +22,17 @@ def is_url(url, check=True): def gsutil_getsize(url=""): - # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du + """ + Returns the size in bytes of a file at a Google Cloud Storage URL using `gsutil du`. + + Returns 0 if the command fails or output is empty. 
+ """ output = subprocess.check_output(["gsutil", "du", url], shell=True, encoding="utf-8") return int(output.split()[0]) if output else 0 def url_getsize(url="https://ultralytics.com/images/bus.jpg"): - # Return downloadable file size in bytes + """Returns the size in bytes of a downloadable file at a given URL; defaults to -1 if not found.""" response = requests.head(url, allow_redirects=True) return int(response.headers.get("content-length", -1)) @@ -54,7 +58,11 @@ def curl_download(url, filename, *, silent: bool = False) -> bool: def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""): - # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes + """ + Downloads a file from a URL (or alternate URL) to a specified path if file is above a minimum size. + + Removes incomplete downloads. + """ from utils.general import LOGGER file = Path(file) @@ -78,7 +86,9 @@ def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""): def attempt_download(file, repo="ultralytics/yolov5", release="v7.0"): - # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v7.0', etc. + """Downloads a file from GitHub release assets or via direct URL if not found locally, supporting backup + versions. + """ from utils.general import LOGGER def github_assets(repository, version="latest"): diff --git a/utils/flask_rest_api/restapi.py b/utils/flask_rest_api/restapi.py index e62c7ebd709f..b9bd16f1a63e 100644 --- a/utils/flask_rest_api/restapi.py +++ b/utils/flask_rest_api/restapi.py @@ -16,6 +16,9 @@ @app.route(DETECTION_URL, methods=["POST"]) def predict(model): + """Predict and return object detections in JSON format given an image and model name via a Flask REST API POST + request. + """ if request.method != "POST": return diff --git a/utils/general.py b/utils/general.py index aa2ed6eb947d..661475354adc 100644 --- a/utils/general.py +++ b/utils/general.py @@ -71,18 +71,18 @@ def is_ascii(s=""): - # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7) + """Checks if input string `s` contains only ASCII characters; returns `True` if so, otherwise `False`.""" s = str(s) # convert list, tuple, None, etc. to str return len(s.encode().decode("ascii", "ignore")) == len(s) def is_chinese(s="人工智能"): - # Is string composed of any Chinese characters? + """Determines if a string `s` contains any Chinese characters; returns `True` if so, otherwise `False`.""" return bool(re.search("[\u4e00-\u9fff]", str(s))) def is_colab(): - # Is environment a Google Colab instance? + """Checks if the current environment is a Google Colab instance; returns `True` for Colab, otherwise `False`.""" return "google.colab" in sys.modules @@ -101,7 +101,7 @@ def is_jupyter(): def is_kaggle(): - # Is environment a Kaggle Notebook? 
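For reference, `is_kaggle()` is a pure environment-variable probe, easy to exercise directly (a sketch; the values mimic Kaggle's defaults):

import os

os.environ["PWD"] = "/kaggle/working"
os.environ["KAGGLE_URL_BASE"] = "https://www.kaggle.com"
print(is_kaggle())  # True once both variables match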
+ """Checks if the current environment is a Kaggle Notebook by validating environment variables.""" return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com" @@ -117,7 +117,7 @@ def is_docker() -> bool: def is_writeable(dir, test=False): - # Return True if directory has write permissions, test opening a file with write permissions if test=True + """Checks if a directory is writable, optionally testing by creating a temporary file if `test=True`.""" if not test: return os.access(dir, os.W_OK) # possible issues on Windows file = Path(dir) / "tmp.txt" @@ -134,7 +134,7 @@ def is_writeable(dir, test=False): def set_logging(name=LOGGING_NAME, verbose=True): - # sets up logging for the given name + """Configures logging with specified verbosity; `name` sets the logger's name, `verbose` controls logging level.""" rank = int(os.getenv("RANK", -1)) # rank in world for Multi-GPU trainings level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR logging.config.dictConfig( @@ -168,7 +168,9 @@ def set_logging(name=LOGGING_NAME, verbose=True): def user_config_dir(dir="Ultralytics", env_var="YOLOV5_CONFIG_DIR"): - # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required. + """Returns user configuration directory path, preferring environment variable `YOLOV5_CONFIG_DIR` if set, else OS- + specific. + """ env = os.getenv(env_var) if env: path = Path(env) # use environment variable @@ -186,19 +188,23 @@ def user_config_dir(dir="Ultralytics", env_var="YOLOV5_CONFIG_DIR"): class Profile(contextlib.ContextDecorator): # YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager def __init__(self, t=0.0, device: torch.device = None): + """Initializes a profiling context for YOLOv5 with optional timing threshold and device specification.""" self.t = t self.device = device self.cuda = bool(device and str(device).startswith("cuda")) def __enter__(self): + """Initializes timing at the start of a profiling context block for performance measurement.""" self.start = self.time() return self def __exit__(self, type, value, traceback): + """Concludes timing, updating duration for profiling upon exiting a context block.""" self.dt = self.time() - self.start # delta-time self.t += self.dt # accumulate dt def time(self): + """Measures and returns the current time, synchronizing CUDA operations if `cuda` is True.""" if self.cuda: torch.cuda.synchronize(self.device) return time.time() @@ -207,19 +213,23 @@ def time(self): class Timeout(contextlib.ContextDecorator): # YOLOv5 Timeout class. 
Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager def __init__(self, seconds, *, timeout_msg="", suppress_timeout_errors=True): + """Initializes a timeout context/decorator with defined seconds, optional message, and error suppression.""" self.seconds = int(seconds) self.timeout_message = timeout_msg self.suppress = bool(suppress_timeout_errors) def _timeout_handler(self, signum, frame): + """Raises a TimeoutError with a custom message when a timeout event occurs.""" raise TimeoutError(self.timeout_message) def __enter__(self): + """Initializes timeout mechanism on non-Windows platforms, starting a countdown to raise TimeoutError.""" if platform.system() != "Windows": # not supported on Windows signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM signal.alarm(self.seconds) # start countdown for SIGALRM to be raised def __exit__(self, exc_type, exc_val, exc_tb): + """Disables active alarm on non-Windows systems and optionally suppresses TimeoutError if set.""" if platform.system() != "Windows": signal.alarm(0) # Cancel SIGALRM if it's scheduled if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError @@ -229,23 +239,26 @@ def __exit__(self, exc_type, exc_val, exc_tb): class WorkingDirectory(contextlib.ContextDecorator): # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager def __init__(self, new_dir): + """Initializes a context manager/decorator to temporarily change the working directory.""" self.dir = new_dir # new dir self.cwd = Path.cwd().resolve() # current dir def __enter__(self): + """Temporarily changes the working directory within a 'with' statement context.""" os.chdir(self.dir) def __exit__(self, exc_type, exc_val, exc_tb): + """Restores the original working directory upon exiting a 'with' statement context.""" os.chdir(self.cwd) def methods(instance): - # Get class/instance methods + """Returns list of method names for a class/instance excluding dunder methods.""" return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] def print_args(args: Optional[dict] = None, show_file=True, show_func=False): - # Print function arguments (optional args dict) + """Logs the arguments of the calling function, with options to include the filename and function name.""" x = inspect.currentframe().f_back # previous frame file, _, func, _, _ = inspect.getframeinfo(x) if args is None: # get args automatically @@ -260,7 +273,11 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False): def init_seeds(seed=0, deterministic=False): - # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html + """ + Initializes RNG seeds and sets deterministic options if specified. + + See https://pytorch.org/docs/stable/notes/randomness.html + """ random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) @@ -275,36 +292,38 @@ def init_seeds(seed=0, deterministic=False): def intersect_dicts(da, db, exclude=()): - # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values + """Returns intersection of `da` and `db` dicts with matching keys and shapes, excluding `exclude` keys; uses `da` + values. 
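The usual caller is checkpoint loading for transfer learning, where only weights whose names and shapes still match are kept; a minimal sketch with toy state dicts (assumes `intersect_dicts` above is in scope):

import torch

da = {"conv.weight": torch.zeros(8, 3, 3, 3), "head.weight": torch.zeros(10, 8), "anchors": torch.zeros(3, 2)}
db = {"conv.weight": torch.ones(8, 3, 3, 3), "head.weight": torch.ones(5, 8)}  # head resized, anchors gone

kept = intersect_dicts(da, db, exclude=("anchors",))
print(list(kept))  # ['conv.weight']: shape mismatches and excluded keys are dropped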
+ """ return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape} def get_default_args(func): - # Get func() default arguments + """Returns a dict of `func` default arguments by inspecting its signature.""" signature = inspect.signature(func) return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty} def get_latest_run(search_dir="."): - # Return path to most recent 'last.pt' in /runs (i.e. to --resume from) + """Returns the path to the most recent 'last.pt' file in /runs to resume from, searches in `search_dir`.""" last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) return max(last_list, key=os.path.getctime) if last_list else "" def file_age(path=__file__): - # Return days since last file update + """Calculates and returns the age of a file in days based on its last modification time.""" dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta return dt.days # + dt.seconds / 86400 # fractional days def file_date(path=__file__): - # Return human-readable file modification date, i.e. '2021-3-26' + """Returns a human-readable file modification date in 'YYYY-M-D' format, given a file path.""" t = datetime.fromtimestamp(Path(path).stat().st_mtime) return f"{t.year}-{t.month}-{t.day}" def file_size(path): - # Return file/dir size (MB) + """Returns file or directory size in megabytes (MB) for a given path, where directories are recursively summed.""" mb = 1 << 20 # bytes to MiB (1024 ** 2) path = Path(path) if path.is_file(): @@ -316,7 +335,9 @@ def file_size(path): def check_online(): - # Check internet connectivity + """Checks internet connectivity by attempting to create a connection to "1.1.1.1" on port 443, retries once if the + first attempt fails. + """ import socket def run_once(): @@ -342,7 +363,9 @@ def git_describe(path=ROOT): # path must be a directory @TryExcept() @WorkingDirectory(ROOT) def check_git_status(repo="ultralytics/yolov5", branch="master"): - # YOLOv5 status check, recommend 'git pull' if code is out of date + """Checks if YOLOv5 code is up-to-date with the repository, advising 'git pull' if behind; errors return informative + messages. + """ url = f"https://github.com/{repo}" msg = f", for updates see {url}" s = colorstr("github: ") # string @@ -369,7 +392,7 @@ def check_git_status(repo="ultralytics/yolov5", branch="master"): @WorkingDirectory(ROOT) def check_git_info(path="."): - # YOLOv5 git info check, return {remote, branch, commit} + """Checks YOLOv5 git info, returning a dict with remote URL, branch name, and commit hash.""" check_requirements("gitpython") import git @@ -387,12 +410,12 @@ def check_git_info(path="."): def check_python(minimum="3.8.0"): - # Check current python version vs. required python version + """Checks if current Python version meets the minimum required version, exits if not.""" check_version(platform.python_version(), minimum, name="Python ", hard=True) def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False): - # Check version vs. 
required version + """Checks if the current version meets the minimum required version, exits or warns based on parameters.""" current, minimum = (pkg.parse_version(x) for x in (current, minimum)) result = (current == minimum) if pinned else (current >= minimum) # bool s = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed" # string @@ -404,7 +427,7 @@ def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=Fals def check_img_size(imgsz, s=32, floor=0): - # Verify image size is a multiple of stride s in each dimension + """Adjusts image size to be divisible by stride `s`, supports int or list/tuple input, returns adjusted size.""" if isinstance(imgsz, int): # integer i.e. img_size=640 new_size = max(make_divisible(imgsz, int(s)), floor) else: # list i.e. img_size=[640, 480] @@ -416,7 +439,7 @@ def check_img_size(imgsz, s=32, floor=0): def check_imshow(warn=False): - # Check if environment supports image displays + """Checks environment support for image display; warns on failure if `warn=True`.""" try: assert not is_jupyter() assert not is_docker() @@ -432,7 +455,7 @@ def check_imshow(warn=False): def check_suffix(file="yolov5s.pt", suffix=(".pt",), msg=""): - # Check file(s) for acceptable suffix + """Validates if a file or files have an acceptable suffix, raising an error if not.""" if file and suffix: if isinstance(suffix, str): suffix = [suffix] @@ -443,12 +466,12 @@ def check_suffix(file="yolov5s.pt", suffix=(".pt",), msg=""): def check_yaml(file, suffix=(".yaml", ".yml")): - # Search/download YAML file (if necessary) and return path, checking suffix + """Searches/downloads a YAML file, verifies its suffix (.yaml or .yml), and returns the file path.""" return check_file(file, suffix) def check_file(file, suffix=""): - # Search/download file (if necessary) and return path + """Searches/downloads a file, checks its suffix (if provided), and returns the file path.""" check_suffix(file, suffix) # optional file = str(file) # convert to str() if os.path.isfile(file) or not file: # exists @@ -478,7 +501,7 @@ def check_file(file, suffix=""): def check_font(font=FONT, progress=False): - # Download font to CONFIG_DIR if necessary + """Ensures specified font exists or downloads it from Ultralytics assets, optionally displaying progress.""" font = Path(font) file = CONFIG_DIR / font.name if not font.exists() and not file.exists(): @@ -488,7 +511,7 @@ def check_font(font=FONT, progress=False): def check_dataset(data, autodownload=True): - # Download, check and/or unzip dataset if not found locally + """Validates and/or auto-downloads a dataset, returning its configuration as a dictionary.""" # Download (optional) extract_dir = "" @@ -554,7 +577,7 @@ def check_dataset(data, autodownload=True): def check_amp(model): - # Check PyTorch Automatic Mixed Precision (AMP) functionality. 
Return True on correct operation + """Checks PyTorch AMP functionality for a model, returns True if AMP operates correctly, otherwise False.""" from models.common import AutoShape, DetectMultiBackend def amp_allclose(model, im): @@ -582,19 +605,23 @@ def amp_allclose(model, im): def yaml_load(file="data.yaml"): - # Single-line safe yaml loading + """Safely loads and returns the contents of a YAML file specified by `file` argument.""" with open(file, errors="ignore") as f: return yaml.safe_load(f) def yaml_save(file="data.yaml", data={}): - # Single-line safe yaml saving + """Safely saves `data` to a YAML file specified by `file`, converting `Path` objects to strings; `data` is a + dictionary. + """ with open(file, "w") as f: yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False) def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX")): - # Unzip a *.zip file to path/, excluding files containing strings in exclude list + """Unzips `file` to `path` (default: file's parent), excluding filenames containing any in `exclude` (`.DS_Store`, + `__MACOSX`). + """ if path is None: path = Path(file).parent # default path with ZipFile(file) as zipObj: @@ -604,13 +631,18 @@ def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX")): def url2file(url): - # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt + """ + Converts a URL string to a valid filename by stripping protocol, domain, and any query parameters. + + Example https://url.com/file.txt?auth -> file.txt + """ url = str(Path(url)).replace(":/", "://") # Pathlib turns :// -> :/ return Path(urllib.parse.unquote(url)).name.split("?")[0] # '%2F' to '/', split https://url.com/file.txt?auth def download(url, dir=".", unzip=True, delete=True, curl=False, threads=1, retry=3): - # Multithreaded file download and unzip function, used in data.yaml for autodownload + """Downloads and optionally unzips files concurrently, supporting retries and curl fallback.""" + def download_one(url, dir): # Download 1 file success = True @@ -656,24 +688,34 @@ def download_one(url, dir): def make_divisible(x, divisor): - # Returns nearest x divisible by divisor + """Adjusts `x` to be divisible by `divisor`, returning the nearest greater or equal value.""" if isinstance(divisor, torch.Tensor): divisor = int(divisor.max()) # to int return math.ceil(x / divisor) * divisor def clean_str(s): - # Cleans a string by replacing special characters with underscore _ + """Cleans a string by replacing special characters with underscore, e.g., `clean_str('#example!')` returns + '_example_'. + """ return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s) def one_cycle(y1=0.0, y2=1.0, steps=100): - # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf + """ + Generates a lambda for a sinusoidal ramp from y1 to y2 over 'steps'. + + See https://arxiv.org/pdf/1812.01187.pdf for details. + """ return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1 def colorstr(*input): - # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world') + """ + Colors a string using ANSI escape codes, e.g., colorstr('blue', 'hello world'). + + See https://en.wikipedia.org/wiki/ANSI_escape_code. 
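A brief usage sketch (assumed from the argument unpacking below; with a single argument the colors default to blue and bold):

    print(colorstr("blue", "bold", "hello world"))  # explicit colors, then the string
    print(colorstr("hello world"))  # defaults to blue + bold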
+ """ *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string colors = { "black": "\033[30m", # basic colors @@ -700,7 +742,7 @@ def colorstr(*input): def labels_to_class_weights(labels, nc=80): - # Get class weights (inverse frequency) from training labels + """Calculates class weights from labels to handle class imbalance in training; input shape: (n, 5).""" if labels[0] is None: # no labels loaded return torch.Tensor() @@ -719,7 +761,7 @@ def labels_to_class_weights(labels, nc=80): def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)): - # Produces image weights based on class_weights and image contents + """Calculates image weights from labels using class weights for weighted sampling.""" # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels]) return (class_weights.reshape(1, nc) * class_counts).sum(1) @@ -816,7 +858,7 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper) def xyxy2xywh(x): - # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right + """Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right.""" y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center @@ -826,7 +868,7 @@ def xyxy2xywh(x): def xywh2xyxy(x): - # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.""" y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y @@ -836,7 +878,7 @@ def xywh2xyxy(x): def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): - # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + """Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.""" y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y @@ -846,7 +888,7 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): - # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right + """Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right.""" if clip: clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) @@ -858,7 +900,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): def xyn2xy(x, w=640, h=640, padw=0, padh=0): - # Convert normalized segments into pixel segments, shape (n,2) + """Convert normalized segments into pixel segments, shape (n,2).""" y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[..., 0] = w * x[..., 0] + padw # top left x y[..., 1] = h * x[..., 1] + padh # top left y @@ -866,7 +908,7 @@ def xyn2xy(x, w=640, h=640, padw=0, padh=0): def segment2box(segment, width=640, height=640): - # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) 
to (xyxy) + """Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).""" x, y = segment.T # segment xy inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height) ( @@ -877,7 +919,7 @@ def segment2box(segment, width=640, height=640): def segments2boxes(segments): - # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh) + """Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).""" boxes = [] for s in segments: x, y = s.T # segment xy @@ -886,7 +928,7 @@ def segments2boxes(segments): def resample_segments(segments, n=1000): - # Up-sample an (n,2) segment + """Resamples an (n,2) segment to a fixed number of points for consistent representation.""" for i, s in enumerate(segments): s = np.concatenate((s, s[0:1, :]), axis=0) x = np.linspace(0, len(s) - 1, n) @@ -896,7 +938,7 @@ def resample_segments(segments, n=1000): def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): - # Rescale boxes (xyxy) from img1_shape to img0_shape + """Rescales (xyxy) bounding boxes from img1_shape to img0_shape, optionally using provided `ratio_pad`.""" if ratio_pad is None: # calculate from img0_shape gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding @@ -912,7 +954,7 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False): - # Rescale coords (xyxy) from img1_shape to img0_shape + """Rescales segment coordinates from img1_shape to img0_shape, optionally normalizing them with custom padding.""" if ratio_pad is None: # calculate from img0_shape gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding @@ -931,7 +973,7 @@ def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=F def clip_boxes(boxes, shape): - # Clip boxes (xyxy) to image shape (height, width) + """Clips bounding box coordinates (xyxy) to fit within the specified image shape (height, width).""" if isinstance(boxes, torch.Tensor): # faster individually boxes[..., 0].clamp_(0, shape[1]) # x1 boxes[..., 1].clamp_(0, shape[0]) # y1 @@ -943,7 +985,7 @@ def clip_boxes(boxes, shape): def clip_segments(segments, shape): - # Clip segments (xy1,xy2,...) to image shape (height, width) + """Clips segment coordinates (xy1, xy2, ...) 
to an image's boundaries given its shape (height, width).""" if isinstance(segments, torch.Tensor): # faster individually segments[:, 0].clamp_(0, shape[1]) # x segments[:, 1].clamp_(0, shape[0]) # y @@ -1083,6 +1125,7 @@ def strip_optimizer(f="best.pt", s=""): # from utils.general import *; strip_op def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr("evolve: ")): + """Logs evolution results and saves to CSV and YAML in `save_dir`, optionally syncs with `bucket`.""" evolve_csv = save_dir / "evolve.csv" evolve_yaml = save_dir / "hyp_evolve.yaml" keys = tuple(keys) + tuple(hyp.keys()) # [results + hyps] @@ -1137,7 +1180,7 @@ def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr("evolve def apply_classifier(x, model, img, im0): - # Apply a second stage classifier to YOLO outputs + """Applies second-stage classifier to YOLO outputs, filtering detections by class match.""" # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval() im0 = [im0] if isinstance(im0, np.ndarray) else im0 for i, d in enumerate(x): # per image @@ -1172,7 +1215,12 @@ def apply_classifier(x, model, img, im0): def increment_path(path, exist_ok=False, sep="", mkdir=False): - # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. + """ + Generates an incremented file or directory path if it exists, with optional mkdir; args: path, exist_ok=False, + sep="", mkdir=False. + + Example: runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc + """ path = Path(path) # os-agnostic if path.exists() and not exist_ok: path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "") @@ -1202,10 +1250,14 @@ def increment_path(path, exist_ok=False, sep="", mkdir=False): def imread(filename, flags=cv2.IMREAD_COLOR): + """Reads an image from a file and returns it as a numpy array, using OpenCV's imdecode to support multilanguage + paths. 
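Usage sketch (the path is illustrative):

    im = imread("日本語/image.jpg")  # returns a BGR numpy array despite the non-ASCII path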
+ """ return cv2.imdecode(np.fromfile(filename, np.uint8), flags) def imwrite(filename, img): + """Writes an image to a file, returns True on success and False on failure, supports multilanguage paths.""" try: cv2.imencode(Path(filename).suffix, img)[1].tofile(filename) return True @@ -1214,6 +1266,7 @@ def imwrite(filename, img): def imshow(path, im): + """Displays an image using Unicode path, requires encoded path and image matrix as input.""" imshow_(path.encode("unicode_escape").decode(), im) diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index 36792979913a..c3fbded50a3c 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -73,6 +73,7 @@ def _json_default(value): class Loggers: # YOLOv5 Loggers class def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS): + """Initializes loggers for YOLOv5 training and validation metrics, paths, and options.""" self.save_dir = save_dir self.weights = weights self.opt = opt @@ -150,7 +151,7 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, @property def remote_dataset(self): - # Get data_dict if custom dataset artifact link is provided + """Fetches dataset dictionary from remote logging services like ClearML, Weights & Biases, or Comet ML.""" data_dict = None if self.clearml: data_dict = self.clearml.data_dict @@ -162,15 +163,17 @@ def remote_dataset(self): return data_dict def on_train_start(self): + """Initializes the training process for Comet ML logger if it's configured.""" if self.comet_logger: self.comet_logger.on_train_start() def on_pretrain_routine_start(self): + """Invokes pre-training routine start hook for Comet ML logger if available.""" if self.comet_logger: self.comet_logger.on_pretrain_routine_start() def on_pretrain_routine_end(self, labels, names): - # Callback runs on pre-train routine end + """Callback that runs at the end of pre-training routine, logging label plots if enabled.""" if self.plots: plot_labels(labels, names, self.save_dir) paths = self.save_dir.glob("*labels*.jpg") # training labels @@ -183,6 +186,7 @@ def on_pretrain_routine_end(self, labels, names): self.clearml.log_plot(title=path.stem, plot_path=path) def on_train_batch_end(self, model, ni, imgs, targets, paths, vals): + """Logs training batch end events, plots images, and updates external loggers with batch-end data.""" log_dict = dict(zip(self.keys[:3], vals)) # Callback runs on train batch end # ni: number integrated batches (since train start) @@ -203,7 +207,7 @@ def on_train_batch_end(self, model, ni, imgs, targets, paths, vals): self.comet_logger.on_train_batch_end(log_dict, step=ni) def on_train_epoch_end(self, epoch): - # Callback runs on train epoch end + """Callback that updates the current epoch in Weights & Biases at the end of a training epoch.""" if self.wandb: self.wandb.current_epoch = epoch + 1 @@ -211,22 +215,24 @@ def on_train_epoch_end(self, epoch): self.comet_logger.on_train_epoch_end(epoch) def on_val_start(self): + """Callback that signals the start of a validation phase to the Comet logger.""" if self.comet_logger: self.comet_logger.on_val_start() def on_val_image_end(self, pred, predn, path, names, im): - # Callback runs on val image end + """Callback that logs a validation image and its predictions to WandB or ClearML.""" if self.wandb: self.wandb.val_one_image(pred, predn, path, names, im) if self.clearml: self.clearml.log_image_with_boxes(path, pred, names, im) def on_val_batch_end(self, batch_i, im, targets, paths, 
shapes, out): + """Logs validation batch results to Comet ML during training at the end of each validation batch.""" if self.comet_logger: self.comet_logger.on_val_batch_end(batch_i, im, targets, paths, shapes, out) def on_val_end(self, nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix): - # Callback runs on val end + """Logs validation results to WandB or ClearML at the end of the validation process.""" if self.wandb or self.clearml: files = sorted(self.save_dir.glob("val*.jpg")) if self.wandb: @@ -238,7 +244,7 @@ def on_val_end(self, nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix) self.comet_logger.on_val_end(nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix) def on_fit_epoch_end(self, vals, epoch, best_fitness, fi): - # Callback runs at the end of each fit (train+val) epoch + """Callback that logs metrics and saves them to CSV or NDJSON at the end of each fit (train+val) epoch.""" x = dict(zip(self.keys, vals)) if self.csv: file = self.save_dir / "results.csv" @@ -277,7 +283,7 @@ def on_fit_epoch_end(self, vals, epoch, best_fitness, fi): self.comet_logger.on_fit_epoch_end(x, epoch=epoch) def on_model_save(self, last, epoch, final_epoch, best_fitness, fi): - # Callback runs on model save event + """Callback that handles model saving events, logging to Weights & Biases or ClearML if enabled.""" if (epoch + 1) % self.opt.save_period == 0 and not final_epoch and self.opt.save_period != -1: if self.wandb: self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi) @@ -290,7 +296,7 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi): self.comet_logger.on_model_save(last, epoch, final_epoch, best_fitness, fi) def on_train_end(self, last, best, epoch, results): - # Callback runs on training end, i.e. 
saving best model + """Callback that runs at the end of training to save plots and log results.""" if self.plots: plot_results(file=self.save_dir / "results.csv") # save results.png files = ["results.png", "confusion_matrix.png", *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R"))] @@ -326,7 +332,7 @@ def on_train_end(self, last, best, epoch, results): self.comet_logger.on_train_end(files, self.save_dir, last, best, epoch, final_results) def on_params_update(self, params: dict): - # Update hyperparams or configs of the experiment + """Updates experiment hyperparameters or configurations in WandB, Comet, or ClearML.""" if self.wandb: self.wandb.wandb_run.config.update(params, allow_val_change=True) if self.comet_logger: @@ -346,7 +352,7 @@ class GenericLogger: """ def __init__(self, opt, console_logger, include=("tb", "wandb", "clearml")): - # init default loggers + """Initializes a generic logger with optional TensorBoard, W&B, and ClearML support.""" self.save_dir = Path(opt.save_dir) self.include = include self.console_logger = console_logger @@ -381,7 +387,7 @@ def __init__(self, opt, console_logger, include=("tb", "wandb", "clearml")): self.clearml = None def log_metrics(self, metrics, epoch): - # Log metrics dictionary to all loggers + """Logs metrics to CSV, TensorBoard, W&B, and ClearML; `metrics` is a dict, `epoch` is an int.""" if self.csv: keys, vals = list(metrics.keys()), list(metrics.values()) n = len(metrics) + 1 # number of cols @@ -400,7 +406,7 @@ def log_metrics(self, metrics, epoch): self.clearml.log_scalars(metrics, epoch) def log_images(self, files, name="Images", epoch=0): - # Log images to all loggers + """Logs images to all loggers with optional naming and epoch specification.""" files = [Path(f) for f in (files if isinstance(files, (tuple, list)) else [files])] # to Path files = [f for f in files if f.exists()] # filter by exists @@ -418,11 +424,12 @@ def log_images(self, files, name="Images", epoch=0): self.clearml.log_debug_samples(files, title=name) def log_graph(self, model, imgsz=(640, 640)): - # Log model graph to all loggers + """Logs model graph to all configured loggers with specified input image size.""" if self.tb: log_tensorboard_graph(self.tb, model, imgsz) def log_model(self, model_path, epoch=0, metadata=None): + """Logs the model to all configured loggers with optional epoch and metadata.""" if metadata is None: metadata = {} # Log model to all loggers @@ -434,7 +441,7 @@ def log_model(self, model_path, epoch=0, metadata=None): self.clearml.log_model(model_path=model_path, model_name=model_path.stem) def update_params(self, params): - # Update the parameters logged + """Updates logged parameters in WandB and/or ClearML if enabled.""" if self.wandb: wandb.run.config.update(params, allow_val_change=True) if self.clearml: @@ -442,7 +449,7 @@ def update_params(self, params): def log_tensorboard_graph(tb, model, imgsz=(640, 640)): - # Log model graph to TensorBoard + """Logs the model graph to TensorBoard with specified image size and model.""" try: p = next(model.parameters()) # for device, type imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz # expand @@ -455,7 +462,7 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)): def web_project_name(project): - # Convert local project name to web project name + """Converts a local project name to a standardized web project name with optional suffixes.""" if not project.startswith("runs/train"): return project suffix = "-Classify" if project.endswith("-cls") else "-Segment" if 
project.endswith("-seg") else ""
diff --git a/utils/loggers/comet/__init__.py b/utils/loggers/comet/__init__.py
index cec46f5af1fb..076eb3ccecab 100644
--- a/utils/loggers/comet/__init__.py
+++ b/utils/loggers/comet/__init__.py
@@ -165,6 +165,7 @@ def __init__(self, opt, hyp, run_id=None, job_type="Training", **experiment_kwar
self.experiment.log_other("optimizer_parameters", json.dumps(self.hyp))

def _get_experiment(self, mode, experiment_id=None):
+ """Returns a new or existing Comet.ml experiment based on mode and optional experiment_id."""
if mode == "offline":
return (
comet_ml.ExistingOfflineExperiment(
@@ -197,21 +198,27 @@ def _get_experiment(self, mode, experiment_id=None):
return

def log_metrics(self, log_dict, **kwargs):
+ """Logs metrics to the current experiment, accepting a dictionary of metric names and values."""
self.experiment.log_metrics(log_dict, **kwargs)

def log_parameters(self, log_dict, **kwargs):
+ """Logs parameters to the current experiment, accepting a dictionary of parameter names and values."""
self.experiment.log_parameters(log_dict, **kwargs)

def log_asset(self, asset_path, **kwargs):
+ """Logs a file or directory as an asset to the current experiment."""
self.experiment.log_asset(asset_path, **kwargs)

def log_asset_data(self, asset, **kwargs):
+ """Logs in-memory data as an asset to the current experiment, with optional kwargs."""
self.experiment.log_asset_data(asset, **kwargs)

def log_image(self, img, **kwargs):
+ """Logs an image to the current experiment with optional kwargs."""
self.experiment.log_image(img, **kwargs)

def log_model(self, path, opt, epoch, fitness_score, best_model=False):
+ """Logs model checkpoint to experiment with path, options, epoch, fitness, and best model flag."""
if not self.save_model:
return
@@ -235,6 +242,7 @@ def log_model(self, path, opt, epoch, fitness_score, best_model=False):
)

def check_dataset(self, data_file):
+ """Validates the dataset configuration by loading the YAML file specified in `data_file`."""
with open(data_file) as f:
data_config = yaml.safe_load(f)

@@ -247,6 +255,7 @@ def check_dataset(self, data_file):
return check_dataset(data_file)

def log_predictions(self, image, labelsn, path, shape, predn):
+ """Logs predictions with IOU filtering, given image, labels, path, shape, and predictions."""
if self.logged_images_count >= self.max_images:
return
detections = predn[predn[:, 4] > self.conf_thres]
@@ -287,6 +296,7 @@ def log_predictions(self, image, labelsn, path, shape, predn):
return

def preprocess_prediction(self, image, labels, shape, pred):
+ """Processes prediction data, resizing labels and adding dataset metadata."""
nl, _ = labels.shape[0], pred.shape[0]

# Predictions
@@ -306,6 +316,7 @@ def preprocess_prediction(self, image, labels, shape, pred):
return predn, labelsn

def add_assets_to_artifact(self, artifact, path, asset_path, split):
+ """Adds image and label assets to a Comet artifact given dataset split and paths."""
img_paths = sorted(glob.glob(f"{asset_path}/*"))
label_paths = img2label_paths(img_paths)
@@ -331,6 +342,7 @@ def add_assets_to_artifact(self, artifact, path, asset_path, split):
return artifact

def upload_dataset_artifact(self):
+ """Uploads a YOLOv5 dataset as an artifact to the Comet.ml platform."""
dataset_name = self.data_dict.get("dataset_name", "yolov5-dataset")
path = str((ROOT / Path(self.data_dict["path"])).resolve())
@@ -355,6 +367,7 @@ def upload_dataset_artifact(self):
return

def download_dataset_artifact(self, artifact_path):
+ """Downloads a dataset artifact to a 
specified directory using the experiment's logged artifact.""" logged_artifact = self.experiment.get_artifact(artifact_path) artifact_save_dir = str(Path(self.opt.save_dir) / logged_artifact.name) logged_artifact.download(artifact_save_dir) @@ -374,6 +387,7 @@ def download_dataset_artifact(self, artifact_path): return self.update_data_paths(data_dict) def update_data_paths(self, data_dict): + """Updates data paths in the dataset dictionary, defaulting 'path' to an empty string if not present.""" path = data_dict.get("path", "") for split in ["train", "val", "test"]: @@ -386,6 +400,7 @@ def update_data_paths(self, data_dict): return data_dict def on_pretrain_routine_end(self, paths): + """Called at the end of pretraining routine to handle paths if training is not being resumed.""" if self.opt.resume: return @@ -398,20 +413,25 @@ def on_pretrain_routine_end(self, paths): return def on_train_start(self): + """Logs hyperparameters at the start of training.""" self.log_parameters(self.hyp) def on_train_epoch_start(self): + """Called at the start of each training epoch.""" return def on_train_epoch_end(self, epoch): + """Updates the current epoch in the experiment tracking at the end of each epoch.""" self.experiment.curr_epoch = epoch return def on_train_batch_start(self): + """Called at the start of each training batch.""" return def on_train_batch_end(self, log_dict, step): + """Callback function that updates and logs metrics at the end of each training batch if conditions are met.""" self.experiment.curr_step = step if self.log_batch_metrics and (step % self.comet_log_batch_interval == 0): self.log_metrics(log_dict, step=step) @@ -419,6 +439,7 @@ def on_train_batch_end(self, log_dict, step): return def on_train_end(self, files, save_dir, last, best, epoch, results): + """Logs metadata and optionally saves model files at the end of training.""" if self.comet_log_predictions: curr_epoch = self.experiment.curr_epoch self.experiment.log_asset_data(self.metadata_dict, "image-metadata.json", epoch=curr_epoch) @@ -446,12 +467,15 @@ def on_train_end(self, files, save_dir, last, best, epoch, results): self.finish_run() def on_val_start(self): + """Called at the start of validation, currently a placeholder with no functionality.""" return def on_val_batch_start(self): + """Placeholder called at the start of a validation batch with no current functionality.""" return def on_val_batch_end(self, batch_i, images, targets, paths, shapes, outputs): + """Callback executed at the end of a validation batch, conditionally logs predictions to Comet ML.""" if not (self.comet_log_predictions and ((batch_i + 1) % self.comet_log_prediction_interval == 0)): return @@ -470,6 +494,7 @@ def on_val_batch_end(self, batch_i, images, targets, paths, shapes, outputs): return def on_val_end(self, nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix): + """Logs per-class metrics to Comet.ml after validation if enabled and more than one class exists.""" if self.comet_log_per_class_metrics and self.num_classes > 1: for i, c in enumerate(ap_class): class_name = self.class_names[c] @@ -504,14 +529,18 @@ def on_val_end(self, nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix) ) def on_fit_epoch_end(self, result, epoch): + """Logs metrics at the end of each training epoch.""" self.log_metrics(result, epoch=epoch) def on_model_save(self, last, epoch, final_epoch, best_fitness, fi): + """Callback to save model checkpoints periodically if conditions are met.""" if ((epoch + 1) % self.opt.save_period == 0 and not 
final_epoch) and self.opt.save_period != -1:
self.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

def on_params_update(self, params):
+ """Logs updated parameters during training."""
self.log_parameters(params)

def finish_run(self):
+ """Ends the current Comet experiment."""
self.experiment.end()
diff --git a/utils/loggers/comet/comet_utils.py b/utils/loggers/comet/comet_utils.py
index 6e8fad68c6cc..7eca1f504d69 100644
--- a/utils/loggers/comet/comet_utils.py
+++ b/utils/loggers/comet/comet_utils.py
@@ -17,6 +17,7 @@

def download_model_checkpoint(opt, experiment):
+ """Downloads YOLOv5 model checkpoint from Comet ML experiment, updating `opt.weights` with download path."""
model_dir = f"{opt.project}/{experiment.name}"
os.makedirs(model_dir, exist_ok=True)
diff --git a/utils/loggers/comet/hpo.py b/utils/loggers/comet/hpo.py
index a9e6fabec1cd..8ca08ddc858a 100644
--- a/utils/loggers/comet/hpo.py
+++ b/utils/loggers/comet/hpo.py
@@ -25,6 +25,9 @@

def get_args(known=False):
+ """Parses command-line arguments for YOLOv5 training, supporting configuration of weights, data paths,
hyperparameters, and more.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--weights", type=str, default=ROOT / "yolov5s.pt", help="initial weights path")
parser.add_argument("--cfg", type=str, default="", help="model.yaml path")
@@ -83,6 +86,7 @@ def get_args(known=False):

def run(parameters, opt):
+ """Executes YOLOv5 training with given hyperparameters and options, setting up device and training directories."""
hyp_dict = {k: v for k, v in parameters.items() if k not in ["epochs", "batch_size"]}
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
diff --git a/utils/loggers/wandb/wandb_utils.py b/utils/loggers/wandb/wandb_utils.py
index 0af8bda12d85..4083312e6a59 100644
--- a/utils/loggers/wandb/wandb_utils.py
+++ b/utils/loggers/wandb/wandb_utils.py
@@ -152,6 +152,7 @@ def log_model(self, path, opt, epoch, fitness_score, best_model=False):
LOGGER.info(f"Saving model artifact on epoch {epoch + 1}")

def val_one_image(self, pred, predn, path, names, im):
+ """Placeholder for logging a single validation image and its predictions; currently a no-op."""
pass

def log(self, log_dict):
diff --git a/utils/loss.py b/utils/loss.py
index 26b8c06bf333..8a910e12ad6f 100644
--- a/utils/loss.py
+++ b/utils/loss.py
@@ -16,11 +16,17 @@ def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#iss

class BCEBlurWithLogitsLoss(nn.Module):
# BCEwithLogitLoss() with reduced missing label effects.
def __init__(self, alpha=0.05):
+ """Initializes a modified BCEWithLogitsLoss with reduced missing label effects, taking optional alpha smoothing
parameter.
"""
super().__init__()
self.loss_fcn = nn.BCEWithLogitsLoss(reduction="none") # must be nn.BCEWithLogitsLoss()
self.alpha = alpha

def forward(self, pred, true):
+ """Computes modified BCE loss for YOLOv5 with reduced missing label effects, taking pred and true tensors,
returns mean loss.
"""
loss = self.loss_fcn(pred, true)
pred = torch.sigmoid(pred) # prob from logits
dx = pred - true # reduce only missing label effects
@@ -33,6 +39,9 @@ def forward(self, pred, true):

class FocalLoss(nn.Module):
# Wraps focal loss around existing loss_fcn(), i.e. 
criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5) def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): + """Initializes FocalLoss with specified loss function, gamma, and alpha values; modifies loss reduction to + 'none'. + """ super().__init__() self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() self.gamma = gamma @@ -41,6 +50,7 @@ def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): self.loss_fcn.reduction = "none" # required to apply FL to each element def forward(self, pred, true): + """Calculates the focal loss between predicted and true labels using a modified BCEWithLogitsLoss.""" loss = self.loss_fcn(pred, true) # p_t = torch.exp(-loss) # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability @@ -63,6 +73,7 @@ def forward(self, pred, true): class QFocalLoss(nn.Module): # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5) def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): + """Initializes Quality Focal Loss with given loss function, gamma, alpha; modifies reduction to 'none'.""" super().__init__() self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() self.gamma = gamma @@ -71,6 +82,9 @@ def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): self.loss_fcn.reduction = "none" # required to apply FL to each element def forward(self, pred, true): + """Computes the focal loss between `pred` and `true` using BCEWithLogitsLoss, adjusting for imbalance with + `gamma` and `alpha`. + """ loss = self.loss_fcn(pred, true) pred_prob = torch.sigmoid(pred) # prob from logits @@ -91,6 +105,7 @@ class ComputeLoss: # Compute losses def __init__(self, model, autobalance=False): + """Initializes ComputeLoss with model and autobalance option, autobalances losses if True.""" device = next(model.parameters()).device # get model device h = model.hyp # hyperparameters @@ -173,7 +188,9 @@ def __call__(self, p, targets): # predictions, targets return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach() def build_targets(self, p, targets): - # Build targets for compute_loss(), input targets(image,class,x,y,w,h) + """Prepares model targets from input targets (image,class,x,y,w,h) for loss computation, returning class, box, + indices, and anchors. 
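For orientation, a hedged sketch of the expected call (shapes are illustrative): with `p` the per-layer predictions and `targets` an (nt, 6) tensor of (image, class, x, y, w, h) rows, `tcls, tbox, indices, anch = self.build_targets(p, targets)` returns one list entry per detection layer.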
+ """ na, nt = self.na, targets.shape[0] # number of anchors, targets tcls, tbox, indices, anch = [], [], [], [] gain = torch.ones(7, device=self.device) # normalized to gridspace gain diff --git a/utils/metrics.py b/utils/metrics.py index 5f45621dc372..e572355fec1e 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -13,13 +13,13 @@ def fitness(x): - # Model fitness as a weighted combination of metrics + """Calculates fitness of a model using weighted sum of metrics P, R, mAP@0.5, mAP@0.5:0.95.""" w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95] return (x[:, :4] * w).sum(1) def smooth(y, f=0.05): - # Box filter of fraction f + """Applies box filter smoothing to array `y` with fraction `f`, yielding a smoothed array.""" nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd) p = np.ones(nf // 2) # ones padding yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded @@ -126,6 +126,7 @@ def compute_ap(recall, precision): class ConfusionMatrix: # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix def __init__(self, nc, conf=0.25, iou_thres=0.45): + """Initializes ConfusionMatrix with given number of classes, confidence, and IoU threshold.""" self.matrix = np.zeros((nc + 1, nc + 1)) self.nc = nc # number of classes self.conf = conf @@ -179,6 +180,9 @@ def process_batch(self, detections, labels): self.matrix[dc, self.nc] += 1 # predicted background def tp_fp(self): + """Calculates true positives (tp) and false positives (fp) excluding the background class from the confusion + matrix. + """ tp = self.matrix.diagonal() # true positives fp = self.matrix.sum(1) - tp # false positives # fn = self.matrix.sum(0) - tp # false negatives (missed detections) @@ -186,6 +190,7 @@ def tp_fp(self): @TryExcept("WARNING ⚠️ ConfusionMatrix plot failure") def plot(self, normalize=True, save_dir="", names=()): + """Plots confusion matrix using seaborn, optional normalization; can save plot to specified directory.""" import seaborn as sn array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns @@ -217,12 +222,17 @@ def plot(self, normalize=True, save_dir="", names=()): plt.close(fig) def print(self): + """Prints the confusion matrix row-wise, with each class and its predictions separated by spaces.""" for i in range(self.nc + 1): print(" ".join(map(str, self.matrix[i]))) def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): - # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4) + """ + Calculates IoU, GIoU, DIoU, or CIoU between two boxes, supporting xywh/xyxy formats. + + Input shapes are box1(1,4) to box2(n,4). + """ # Get the coordinates of bounding boxes if xywh: # transform from xywh to xyxy @@ -312,7 +322,9 @@ def bbox_ioa(box1, box2, eps=1e-7): def wh_iou(wh1, wh2, eps=1e-7): - # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2 + """Calculates the Intersection over Union (IoU) for two sets of widths and heights; `wh1` and `wh2` should be nx2 + and mx2 tensors. + """ wh1 = wh1[:, None] # [N,1,2] wh2 = wh2[None] # [1,M,2] inter = torch.min(wh1, wh2).prod(2) # [N,M] @@ -324,7 +336,9 @@ def wh_iou(wh1, wh2, eps=1e-7): @threaded def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=()): - # Precision-recall curve + """Plots precision-recall curve, optionally per class, saving to `save_dir`; `px`, `py` are lists, `ap` is Nx2 + array, `names` optional. 
+ """ fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) py = np.stack(py, axis=1) @@ -347,7 +361,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=()): @threaded def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric"): - # Metric-confidence curve + """Plots a metric-confidence curve for model predictions, supporting per-class visualization and smoothing.""" fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) if 0 < len(names) < 21: # display per-class legend if < 21 classes diff --git a/utils/plots.py b/utils/plots.py index 11c96a6372c3..e1b073dfb1ad 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -31,7 +31,11 @@ class Colors: # Ultralytics color palette https://ultralytics.com/ def __init__(self): - # hex = matplotlib.colors.TABLEAU_COLORS.values() + """ + Initializes the Colors class with a palette derived from Ultralytics color scheme, converting hex codes to RGB. + + Colors derived from `hex = matplotlib.colors.TABLEAU_COLORS.values()`. + """ hexs = ( "FF3838", "FF9D97", @@ -58,6 +62,7 @@ def __init__(self): self.n = len(self.palette) def __call__(self, i, bgr=False): + """Returns color from palette by index `i`, in BGR format if `bgr=True`, else RGB; `i` is an integer index.""" c = self.palette[int(i) % self.n] return (c[2], c[1], c[0]) if bgr else c @@ -100,7 +105,11 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detec def hist2d(x, y, n=100): - # 2d histogram used in labels.png and evolve.png + """ + Generates a logarithmic 2D histogram, useful for visualizing label or evolution distributions. + + Used in used in labels.png and evolve.png. + """ xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n) hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges)) xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1) @@ -109,6 +118,7 @@ def hist2d(x, y, n=100): def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): + """Applies a low-pass Butterworth filter to `data` with specified `cutoff`, `fs`, and `order`.""" from scipy.signal import butter, filtfilt # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy @@ -122,7 +132,9 @@ def butter_lowpass(cutoff, fs, order): def output_to_target(output, max_det=300): - # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting + """Converts YOLOv5 model output to [batch_id, class_id, x, y, w, h, conf] format for plotting, limiting detections + to `max_det`. 
+ """ targets = [] for i, o in enumerate(output): box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1) @@ -133,7 +145,7 @@ def output_to_target(output, max_det=300): @threaded def plot_images(images, targets, paths=None, fname="images.jpg", names=None): - # Plot image grid with labels + """Plots an image grid with labels from YOLOv5 predictions or targets, saving to `fname`.""" if isinstance(images, torch.Tensor): images = images.cpu().float().numpy() if isinstance(targets, torch.Tensor): @@ -197,7 +209,7 @@ def plot_images(images, targets, paths=None, fname="images.jpg", names=None): def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=""): - # Plot LR simulating training for full epochs + """Plots learning rate schedule for given optimizer and scheduler, saving plot to `save_dir`.""" optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals y = [] for _ in range(epochs): @@ -295,7 +307,7 @@ def plot_val_study(file="", dir="", x=None): # from utils.plots import *; plot_ @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395 def plot_labels(labels, names=(), save_dir=Path("")): - # plot dataset labels + """Plots dataset labels, saving correlogram and label images, handles classes, and visualizes bounding boxes.""" LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ") c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes nc = int(c.max() + 1) # number of classes @@ -340,7 +352,7 @@ def plot_labels(labels, names=(), save_dir=Path("")): def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path("images.jpg")): - # Show classification image grid with labels (optional) and predictions (optional) + """Displays a grid of images with optional labels and predictions, saving to a file.""" from utils.augmentations import denormalize names = names or [f"class{i}" for i in range(1000)] @@ -397,7 +409,11 @@ def plot_evolve(evolve_csv="path/to/evolve.csv"): # from utils.plots import *; def plot_results(file="path/to/results.csv", dir=""): - # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv') + """ + Plots training results from a 'results.csv' file; accepts file path and directory as arguments. + + Example: from utils.plots import *; plot_results('path/to/results.csv') + """ save_dir = Path(file).parent if file else Path(dir) fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) ax = ax.ravel() @@ -424,7 +440,11 @@ def plot_results(file="path/to/results.csv", dir=""): def profile_idetection(start=0, stop=0, labels=(), save_dir=""): - # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection() + """ + Plots per-image iDetection logs, comparing metrics like storage and performance over time. + + Example: from utils.plots import *; profile_idetection() + """ ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel() s = ["Images", "Free Storage (GB)", "RAM Usage (GB)", "Battery", "dt_raw (ms)", "dt_smooth (ms)", "real-world FPS"] files = list(Path(save_dir).glob("frames*.txt")) @@ -455,7 +475,9 @@ def profile_idetection(start=0, stop=0, labels=(), save_dir=""): def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True): - # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop + """Crops and saves an image from bounding box `xyxy`, applied with `gain` and `pad`, optionally squares and adjusts + for BGR. 
+ """ xyxy = torch.tensor(xyxy).view(-1, 4) b = xyxy2xywh(xyxy) # boxes if square: diff --git a/utils/segment/augmentations.py b/utils/segment/augmentations.py index 56636b65d93a..e13a53d34821 100644 --- a/utils/segment/augmentations.py +++ b/utils/segment/augmentations.py @@ -12,7 +12,11 @@ def mixup(im, labels, segments, im2, labels2, segments2): - # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf + """ + Applies MixUp augmentation blending two images, labels, and segments with a random ratio. + + See https://arxiv.org/pdf/1710.09412.pdf + """ r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0 im = (im * r + im2 * (1 - r)).astype(np.uint8) labels = np.concatenate((labels, labels2), 0) diff --git a/utils/segment/dataloaders.py b/utils/segment/dataloaders.py index b0b3a7424216..9d2e9bef0b09 100644 --- a/utils/segment/dataloaders.py +++ b/utils/segment/dataloaders.py @@ -123,6 +123,7 @@ def __init__( self.overlap = overlap def __getitem__(self, index): + """Returns a transformed item from the dataset at the specified index, handling indexing and image weighting.""" index = self.indices[index] # linear, shuffled, or image_weights hyp = self.hyp @@ -230,7 +231,7 @@ def __getitem__(self, index): return (torch.from_numpy(img), labels_out, self.im_files[index], shapes, masks) def load_mosaic(self, index): - # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic + """Loads 1 image + 3 random images into a 4-image YOLOv5 mosaic, adjusting labels and segments accordingly.""" labels4, segments4 = [], [] s = self.img_size yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y @@ -291,6 +292,7 @@ def load_mosaic(self, index): @staticmethod def collate_fn(batch): + """Custom collation function for DataLoader, batches images, labels, paths, shapes, and segmentation masks.""" img, label, path, shapes, masks = zip(*batch) # transposed batched_masks = torch.cat(masks, 0) for i, l in enumerate(label): diff --git a/utils/segment/general.py b/utils/segment/general.py index 8cbc745b4a90..f292496c0da9 100644 --- a/utils/segment/general.py +++ b/utils/segment/general.py @@ -144,7 +144,9 @@ def masks_iou(mask1, mask2, eps=1e-7): def masks2segments(masks, strategy="largest"): - # Convert masks(n,160,160) into segments(n,xy) + """Converts binary (n,160,160) masks to polygon segments with options for concatenation or selecting the largest + segment. + """ segments = [] for x in masks.int().cpu().numpy().astype("uint8"): c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] diff --git a/utils/segment/loss.py b/utils/segment/loss.py index 1e007271fa9c..29f1bcbb7e77 100644 --- a/utils/segment/loss.py +++ b/utils/segment/loss.py @@ -12,6 +12,9 @@ class ComputeLoss: # Compute losses def __init__(self, model, autobalance=False, overlap=False): + """Initializes the compute loss function for YOLOv5 models with options for autobalancing and overlap + handling. 
+ """ self.sort_obj_iou = False self.overlap = overlap device = next(model.parameters()).device # get model device @@ -109,13 +112,15 @@ def __call__(self, preds, targets, masks): # predictions, targets, model return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach() def single_mask_loss(self, gt_mask, pred, proto, xyxy, area): - # Mask loss for one image + """Calculates and normalizes single mask loss for YOLOv5 between predicted and ground truth masks.""" pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n,32) @ (32,80,80) -> (n,80,80) loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none") return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean() def build_targets(self, p, targets): - # Build targets for compute_loss(), input targets(image,class,x,y,w,h) + """Prepares YOLOv5 targets for loss computation; inputs targets (image, class, x, y, w, h), output target + classes/boxes. + """ na, nt = self.na, targets.shape[0] # number of anchors, targets tcls, tbox, indices, anch, tidxs, xywhn = [], [], [], [], [], [] gain = torch.ones(8, device=self.device) # normalized to gridspace gain diff --git a/utils/segment/metrics.py b/utils/segment/metrics.py index 7811e7eb364a..973b398eb6b9 100644 --- a/utils/segment/metrics.py +++ b/utils/segment/metrics.py @@ -7,7 +7,7 @@ def fitness(x): - # Model fitness as a weighted combination of metrics + """Evaluates model fitness by a weighted sum of 8 metrics, `x`: [N,8] array, weights: [0.1, 0.9] for mAP and F1.""" w = [0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.1, 0.9] return (x[:, :8] * w).sum(1) @@ -128,6 +128,7 @@ def class_result(self, i): return (self.p[i], self.r[i], self.ap50[i], self.ap[i]) def get_maps(self, nc): + """Calculates and returns mean Average Precision (mAP) for each class given number of classes `nc`.""" maps = np.zeros(nc) + self.map for i, c in enumerate(self.ap_class_index): maps[c] = self.ap[i] @@ -162,17 +163,22 @@ def update(self, results): self.metric_mask.update(list(results["masks"].values())) def mean_results(self): + """Computes and returns the mean results for both box and mask metrics by summing their individual means.""" return self.metric_box.mean_results() + self.metric_mask.mean_results() def class_result(self, i): + """Returns the sum of box and mask metric results for a specified class index `i`.""" return self.metric_box.class_result(i) + self.metric_mask.class_result(i) def get_maps(self, nc): + """Calculates and returns the sum of mean average precisions (mAPs) for both box and mask metrics for `nc` + classes. 
+ """ return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc) @property def ap_class_index(self): - # boxes and masks have the same ap_class_index + """Returns the class index for average precision, shared by both box and mask metrics.""" return self.metric_box.ap_class_index diff --git a/utils/segment/plots.py b/utils/segment/plots.py index 0e30c61be66f..ce01988be937 100644 --- a/utils/segment/plots.py +++ b/utils/segment/plots.py @@ -15,7 +15,7 @@ @threaded def plot_images_and_masks(images, targets, masks, paths=None, fname="images.jpg", names=None): - # Plot image grid with labels + """Plots a grid of images, their labels, and masks with optional resizing and annotations, saving to fname.""" if isinstance(images, torch.Tensor): images = images.cpu().float().numpy() if isinstance(targets, torch.Tensor): @@ -111,7 +111,11 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname="images.jpg" def plot_results_with_masks(file="path/to/results.csv", dir="", best=True): - # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv') + """ + Plots training results from CSV files, plotting best or last result highlights based on `best` parameter. + + Example: from utils.plots import *; plot_results('path/to/results.csv') + """ save_dir = Path(file).parent if file else Path(dir) fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True) ax = ax.ravel() diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 6bc4b4c7fd04..c2c760efa404 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -34,7 +34,8 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, "1.9.0")): - # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator + """Applies torch.inference_mode() if torch>=1.9.0, else torch.no_grad() as a decorator for functions.""" + def decorate(fn): return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn) @@ -42,7 +43,9 @@ def decorate(fn): def smartCrossEntropyLoss(label_smoothing=0.0): - # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0 + """Returns a CrossEntropyLoss with optional label smoothing for torch>=1.10.0; warns if smoothing on lower + versions. + """ if check_version(torch.__version__, "1.10.0"): return nn.CrossEntropyLoss(label_smoothing=label_smoothing) if label_smoothing > 0: @@ -51,7 +54,7 @@ def smartCrossEntropyLoss(label_smoothing=0.0): def smart_DDP(model): - # Model DDP creation with checks + """Initializes DistributedDataParallel (DDP) for model training, respecting torch version constraints.""" assert not check_version(torch.__version__, "1.12.0", pinned=True), ( "torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. " "Please upgrade or downgrade torch to use DDP. 
See https://github.com/ultralytics/yolov5/issues/8395" @@ -63,7 +66,7 @@ def smart_DDP(model): def reshape_classifier_output(model, n=1000): - # Update a TorchVision classification model to class count 'n' if required + """Reshapes last layer of model to match class count 'n', supporting Classify, Linear, Sequential types.""" from models.common import Classify name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module @@ -87,7 +90,9 @@ def reshape_classifier_output(model, n=1000): @contextmanager def torch_distributed_zero_first(local_rank: int): - # Decorator to make all processes in distributed training wait for each local_master to do something + """Context manager ensuring ordered operations in distributed training by making all processes wait for the leading + process. + """ if local_rank not in [-1, 0]: dist.barrier(device_ids=[local_rank]) yield @@ -96,7 +101,7 @@ def torch_distributed_zero_first(local_rank: int): def device_count(): - # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows + """Returns the number of available CUDA devices; works on Linux and Windows by invoking `nvidia-smi`.""" assert platform.system() in ("Linux", "Windows"), "device_count() only supported on Linux or Windows" try: cmd = "nvidia-smi -L | wc -l" if platform.system() == "Linux" else 'nvidia-smi -L | find /c /v ""' # Windows @@ -106,7 +111,7 @@ def device_count(): def select_device(device="", batch_size=0, newline=True): - # device = None or 'cpu' or 0 or '0' or '0,1,2,3' + """Selects computing device (CPU, CUDA GPU, MPS) for YOLOv5 model deployment, logging device info.""" s = f"YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} " device = str(device).strip().lower().replace("cuda:", "").replace("none", "") # to string, 'cuda:0' to '0' cpu = device == "cpu" @@ -143,7 +148,7 @@ def select_device(device="", batch_size=0, newline=True): def time_sync(): - # PyTorch-accurate time + """Synchronizes PyTorch for accurate timing, leveraging CUDA if available, and returns the current time.""" if torch.cuda.is_available(): torch.cuda.synchronize() return time.time() @@ -203,16 +208,19 @@ def profile(input, ops, n=10, device=None): def is_parallel(model): - # Returns True if model is of type DP or DDP + """Checks if the model is using Data Parallelism (DP) or Distributed Data Parallelism (DDP).""" return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) def de_parallel(model): - # De-parallelize a model: returns single-GPU model if model is of type DP or DDP + """Returns a single-GPU model by removing Data Parallelism (DP) or Distributed Data Parallelism (DDP) if applied.""" return model.module if is_parallel(model) else model def initialize_weights(model): + """Initializes weights of Conv2d, BatchNorm2d, and activations (Hardswish, LeakyReLU, ReLU, ReLU6, SiLU) in the + model. 
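Call-site sketch (model construction is illustrative):

    model = Model(cfg)  # hypothetical YOLOv5 model
    initialize_weights(model)  # applies the per-module settings in the loop below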
+ """ for m in model.modules(): t = type(m) if t is nn.Conv2d: @@ -225,12 +233,14 @@ def initialize_weights(model): def find_modules(model, mclass=nn.Conv2d): - # Finds layer indices matching module class 'mclass' + """Finds and returns list of layer indices in `model.module_list` matching the specified `mclass`.""" return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)] def sparsity(model): - # Return global model sparsity + """Calculates and returns the global sparsity of a model as the ratio of zero-valued parameters to total + parameters. + """ a, b = 0, 0 for p in model.parameters(): a += p.numel() @@ -239,7 +249,7 @@ def sparsity(model): def prune(model, amount=0.3): - # Prune model to requested global sparsity + """Prunes Conv2d layers in a model to a specified sparsity using L1 unstructured pruning.""" import torch.nn.utils.prune as prune for name, m in model.named_modules(): @@ -250,7 +260,11 @@ def prune(model, amount=0.3): def fuse_conv_and_bn(conv, bn): - # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ + """ + Fuses Conv2d and BatchNorm2d layers into a single Conv2d layer. + + See https://tehnokv.com/posts/fusing-batchnorm-and-conv/. + """ fusedconv = ( nn.Conv2d( conv.in_channels, @@ -280,7 +294,11 @@ def fuse_conv_and_bn(conv, bn): def model_info(model, verbose=False, imgsz=640): - # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320] + """ + Prints model summary including layers, parameters, gradients, and FLOPs; imgsz may be int or list. + + Example: img_size=640 or img_size=[640, 320] + """ n_p = sum(x.numel() for x in model.parameters()) # number parameters n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients if verbose: @@ -319,7 +337,7 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416) def copy_attr(a, b, include=(), exclude=()): - # Copy attributes from b to a, options to only include [...] and to exclude [...] + """Copies attributes from object b to a, optionally filtering with include and exclude lists.""" for k, v in b.__dict__.items(): if (len(include) and k not in include) or k.startswith("_") or k in exclude: continue @@ -328,7 +346,11 @@ def copy_attr(a, b, include=(), exclude=()): def smart_optimizer(model, name="Adam", lr=0.001, momentum=0.9, decay=1e-5): - # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay + """ + Initializes YOLOv5 smart optimizer with 3 parameter groups for different decay configurations. + + Groups are 0) weights with decay, 1) weights no decay, 2) biases no decay. + """ g = [], [], [] # optimizer parameter groups bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. 
BatchNorm2d() for v in model.modules(): @@ -361,7 +383,7 @@ def smart_optimizer(model, name="Adam", lr=0.001, momentum=0.9, decay=1e-5): def smart_hub_load(repo="ultralytics/yolov5", model="yolov5s", **kwargs): - # YOLOv5 torch.hub.load() wrapper with smart error/issue handling + """YOLOv5 torch.hub.load() wrapper with smart error handling, adjusting torch arguments for compatibility.""" if check_version(torch.__version__, "1.9.1"): kwargs["skip_validation"] = True # validation causes GitHub API rate limit errors if check_version(torch.__version__, "1.12.0"): @@ -373,7 +395,7 @@ def smart_hub_load(repo="ultralytics/yolov5", model="yolov5s", **kwargs): def smart_resume(ckpt, optimizer, ema=None, weights="yolov5s.pt", epochs=300, resume=True): - # Resume training from a partially trained checkpoint + """Resumes training from a checkpoint, updating optimizer, ema, and epochs, with optional resume verification.""" best_fitness = 0.0 start_epoch = ckpt["epoch"] + 1 if ckpt["optimizer"] is not None: @@ -397,12 +419,14 @@ def smart_resume(ckpt, optimizer, ema=None, weights="yolov5s.pt", epochs=300, re class EarlyStopping: # YOLOv5 simple early stopper def __init__(self, patience=30): + """Initializes simple early stopping mechanism for YOLOv5, with adjustable patience for non-improving epochs.""" self.best_fitness = 0.0 # i.e. mAP self.best_epoch = 0 self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop self.possible_stop = False # possible stop may occur next epoch def __call__(self, epoch, fitness): + """Evaluates if training should stop based on fitness improvement and patience, returning a boolean.""" if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training self.best_epoch = epoch self.best_fitness = fitness @@ -426,7 +450,9 @@ class ModelEMA: """ def __init__(self, model, decay=0.9999, tau=2000, updates=0): - # Create EMA + """Initializes EMA with model parameters, decay rate, tau for decay adjustment, and update count; sets model to + evaluation mode. + """ self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA self.updates = updates # number of EMA updates self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) @@ -434,7 +460,7 @@ def __init__(self, model, decay=0.9999, tau=2000, updates=0): p.requires_grad_(False) def update(self, model): - # Update EMA parameters + """Updates the Exponential Moving Average (EMA) parameters based on the current model's parameters.""" self.updates += 1 d = self.decay(self.updates) @@ -446,5 +472,7 @@ def update(self, model): # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32' def update_attr(self, model, include=(), exclude=("process_group", "reducer")): - # Update EMA attributes + """Updates EMA attributes by copying specified attributes from model to EMA, excluding certain attributes by + default. 
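Training-loop sketch (the `include` tuple is illustrative):

    ema = ModelEMA(model)
    ema.update(model)  # after each optimizer step
    ema.update_attr(model, include=("yaml", "nc", "hyp", "names", "stride"))  # before checkpointing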
+ """ copy_attr(self.ema, model, include, exclude) diff --git a/utils/triton.py b/utils/triton.py index 9584d07fbcf0..87524c9c7801 100644 --- a/utils/triton.py +++ b/utils/triton.py @@ -71,6 +71,7 @@ def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.Tuple[t return result[0] if len(result) == 1 else result def _create_inputs(self, *args, **kwargs): + """Creates input tensors from args or kwargs, not both; raises error if none or both are provided.""" args_len, kwargs_len = len(args), len(kwargs) if not args_len and not kwargs_len: raise RuntimeError("No inputs provided.") diff --git a/val.py b/val.py index 6cc1d37a0a26..1c8c65ba89aa 100644 --- a/val.py +++ b/val.py @@ -62,7 +62,7 @@ def save_one_txt(predn, save_conf, shape, file): - # Save one txt result + """Saves one detection result to a txt file in normalized xywh format, optionally including confidence.""" gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh for *xyxy, conf, cls in predn.tolist(): xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh @@ -72,7 +72,11 @@ def save_one_txt(predn, save_conf, shape, file): def save_one_json(predn, jdict, path, class_map): - # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} + """ + Saves one JSON detection result with image ID, category ID, bounding box, and score. + + Example: {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} + """ image_id = int(path.stem) if path.stem.isnumeric() else path.stem box = xyxy2xywh(predn[:, :4]) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner @@ -359,6 +363,7 @@ def run( def parse_opt(): + """Parses command-line options for YOLOv5 model inference configuration.""" parser = argparse.ArgumentParser() parser.add_argument("--data", type=str, default=ROOT / "data/coco128.yaml", help="dataset.yaml path") parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="model path(s)") @@ -391,6 +396,9 @@ def parse_opt(): def main(opt): + """Executes YOLOv5 tasks like training, validation, testing, speed, and study benchmarks based on provided + options. + """ check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop")) if opt.task in ("train", "val", "test"): # run normally