Skip to content

Commit

Permalink
fix: support task map attribute for yolo >= 8.0.44 (#67)
Browse files Browse the repository at this point in the history
* support task map attribute for yolo >= 8.0.44

* update version

---------

Co-authored-by: fatih cagatay akyon <[email protected]>
  • Loading branch information
wadhah101 and fcakyon authored Feb 3, 2024
1 parent 0350758 commit f50987f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
6 changes: 6 additions & 0 deletions tests/test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
hub_id = "ultralyticsplus/yolov8s"


# for ultralytics < 8.0.44
def test_load_from_hub():
path = download_from_hub(hub_id)


# for ultralytics >= 8.0.44
def test_load_from_hub_yolo_8_0_44():
model = YOLO("keremberke/yolov8n-table-extraction")


def test_yolo_from_hub():
model = YOLO(hub_id)

Expand Down
2 changes: 1 addition & 1 deletion ultralyticsplus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .hf_utils import download_from_hub, push_to_hfhub
from .ultralytics_utils import YOLO, postprocess_classify_output, render_result

__version__ = "0.0.29"
__version__ = "0.1.0"
28 changes: 22 additions & 6 deletions ultralyticsplus/ultralytics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,28 @@ def _load_from_hf_hub(self, weights: str, hf_token=None):
self.task = self.model.args["task"]
self.overrides = self.model.args
self._reset_ckpt_args(self.overrides)
(
self.ModelClass,
self.TrainerClass,
self.ValidatorClass,
self.PredictorClass,
) = self._assign_ops_from_task()

# for loading a model with ultralytics <8.0.44
if hasattr(self, "_assign_ops_from_task"):
(
self.ModelClass,
self.TrainerClass,
self.ValidatorClass,
self.PredictorClass,
) = self._assign_ops_from_task()

# for loading a model with ultralytics >=8.0.44
else:
if self.task not in self.task_map:
raise ValueError(
f"Task '{self.task}' not supported. Supported tasks: {list(self.task_map.keys())}"
)
(
self.ModelClass,
self.TrainerClass,
self.ValidatorClass,
self.PredictorClass,
) = self.task_map[self.task]


def render_result(
Expand Down

0 comments on commit f50987f

Please sign in to comment.