Error loading custom yolov5-rt-stack model in C++ converted from yolov5_v4.0 - file not found: archive/constants.pkl #142

Closed
mattpopovich opened this issue Jul 15, 2021 · 14 comments
Labels
bug / fix Something isn't working

Comments

@mattpopovich
Contributor

🐛 Bug

After training a yolov5-v4.0 model, I took its best.pt weights and converted them to yolov5-rt-stack weights in Python using update_module_state_from_ultralytics(). I then passed these yolov5-rt-stack weights as an argument to the stock ./yolo_inference program.
Is that the correct procedure?

I get the following error:

root@pc:~yolov5-rt-stack/deployment/libtorch/build# ./yolo_inference --input_source path/to/jpg --checkpoint path/to/yolov5-rt-stack-yolov5_v4-model.pt --labelmap path/to/names 
>>> Set CPU mode
>>> Loading model
>>> Error loading the model: [enforce fail at inline_container.cc:222] . file not found: archive/constants.pkl
frame #0: c10::ThrowEnforceNotMet(char const*, int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, void const*) + 0x68 (0x7f235b0bda28 in /usr/local/lib/python3.8/dist-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::getRecordID(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xda (0x7f23506bb70a in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #2: caffe2::serialize::PyTorchStreamReader::getRecord(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x38 (0x7f23506bb768 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::jit::readArchiveAndTensors(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<std::function<c10::StrongTypePtr (c10::QualifiedName const&)> >, c10::optional<std::function<c10::intrusive_ptr<c10::ivalue::Object, c10::detail::intrusive_target_default_null_type<c10::ivalue::Object> > (c10::StrongTypePtr, c10::IValue)> >, c10::optional<c10::Device>, caffe2::serialize::PyTorchStreamReader&) + 0xab (0x7f2351d732db in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x3d01835 (0x7f2351d73835 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x3d04013 (0x7f2351d76013 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #6: torch::jit::load(std::shared_ptr<caffe2::serialize::ReadAdapterInterface>, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x1ab (0x7f2351d7710b in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #7: torch::jit::load(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0xc2 (0x7f2351d78dd2 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #8: torch::jit::load(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<c10::Device>) + 0x6a (0x7f2351d78eba in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x24709 (0x562e7d327709 in ./yolo_inference)
frame #10: __libc_start_main + 0xf3 (0x7f230cc060b3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #11: <unknown function> + 0x2315e (0x562e7d32615e in ./yolo_inference)

I do not have any errors and everything runs correctly when using yolov5-rt-stack/test/tracing/yolov5s.torchscript.pt weights with ./yolo_inference, so I believe I have everything installed, compiling, and running correctly.

Expected behavior

./yolo_inference runs correctly, finishes, and outputs detections.

Environment


root@pc:~# python3 -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 1.8.0a0+56b43f4
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: GeForce GTX 1080
GPU 1: GeForce GTX 1080
GPU 2: GeForce GTX 1080

Nvidia driver version: 460.84
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.0
[pip3] pytorch-lightning==1.3.8
[pip3] torch==1.8.0a0+56b43f4
[pip3] torchmetrics==0.4.1
[pip3] torchvision==0.9.0a0+01dfa8e
[conda] Could not collect

  • PyTorch / torchvision Version (e.g., 1.0 / 0.4.0): 1.8.0 / 0.9.0
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch / torchvision (conda, pip, source): source / source
  • Python version: 3.8
  • CUDA/cuDNN version: 11.2
  • GPU models and configuration: 3x GeForce GTX 1080

Additional context

I trained my yolov5_v4.0 model with the following command (on 4x Tesla V100's):
train.py --local_rank=0 --img 512 --batch 128 --epochs 500 --data path/to/yaml --weights --cfg models/yolov5s.yaml --device 0,1,2,3 --name experiment-2

@mattpopovich mattpopovich added the bug / fix Something isn't working label Jul 15, 2021
@mattpopovich
Contributor Author

mattpopovich commented Jul 15, 2021

Going to try and test on some other models I've trained. I might try and re-train and edit some of the arguments that I pass during training (such as removing --img 512) to see if that helps fix anything. Just wanted to make this issue in the interim.

@mattpopovich
Contributor Author

Update: it looks like I was training off of the yolov5 code as of March 14, whereas yolov5-v4.0 was released on Jan 4. So I might have been in a weird hybrid between v4.0 and v5.0.

I'm going to check out v4.0 of yolov5, retrain, and try again.

@zhiqwang
Owner

zhiqwang commented Jul 16, 2021

Hi @mattpopovich

Did you try generating the TorchScript checkpoint from yolov5-rt-stack-yolov5_v4-model.pt? A TorchScript checkpoint is not the same as a vanilla PyTorch .pth checkpoint. We can use the following script to generate it:

import torch
from yolort.models import yolov5s

model_yolort = yolov5s()
model_yolort.model.load_state_dict(torch.load("yolov5-rt-stack-yolov5_v4-model.pt"))  # Edited
# Compile the PyTorch model with torch.jit.script (scripting, not tracing)
scripted_model = torch.jit.script(model_yolort)
scripted_model.save("yolov5-rt-stack-yolov5_v4-model.torchscript.pt")

Then use the generated yolov5-rt-stack-yolov5_v4-model.torchscript.pt for inference on the LibTorch backend.

If the inference results on the PyTorch backend are consistent between yolov5 v5.0 (or the master branch) and yolort, that should be OK. The model structures of v4.0 and v5.0 are very similar; in other words, the current r4.0 version of yolort is compatible with yolov5 v5.0 and the master branch.

I guess that yolov5 v5.0 fuses its model automatically, which is why we have problems converting yolov5 models from v5.0 and the master branch.
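
The `file not found: archive/constants.pkl` error reported above fits this explanation: as far as I know, both torch.save and torch.jit.save write zip archives, but only the TorchScript one carries a constants.pkl record. Below is a minimal, stdlib-only sketch for checking a file before handing it to ./yolo_inference. The helper name `looks_like_torchscript` is made up for illustration, and the check is a heuristic based on the record name in the error message, not an official PyTorch API.

```python
import zipfile


def looks_like_torchscript(path):
    """Heuristic: True if `path` is a zip archive holding a constants.pkl record.

    Assumption: TorchScript archives contain `<prefix>/constants.pkl`
    (the record named in the error message), while a state_dict written
    by torch.save does not.
    """
    if not zipfile.is_zipfile(path):
        # Legacy (non-zip) torch.save files, or not a checkpoint at all.
        return False
    with zipfile.ZipFile(path) as archive:
        return any(name.endswith("/constants.pkl") for name in archive.namelist())
```

Under that assumption, a plain state_dict checkpoint should make this return False, which would be a hint that the torch.jit.script step is still missing.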

@mattpopovich
Contributor Author

My "pipeline" pseudocode is as follows:

best.pt = ultralytics-yolov5(dataset)

model = update_module_state_from_ultralytics(arch='yolov5s',
                                             version='v4.0',
                                             custom_path_or_model=/path/to/best.pt)
torch.save(model.state_dict(), /path/to/yolort-best.pt)

./yolo_inference --input_source /path/to/img --checkpoint /path/to/yolort-best.pt

I believe you're saying I'm missing a step in my pipeline and it should be like (>>> denotes added lines):

best.pt = ultralytics-yolov5(dataset)

model = update_module_state_from_ultralytics(arch='yolov5s',
                                             version='v4.0',
                                             custom_path_or_model=/path/to/best.pt)
torch.save(model.state_dict(), /path/to/yolort-best.pt)

>>>from yolort.models import yolov5s
>>>updated_model = yolov5s()
>>>updated_model.load_state_dict(torch.load(/path/to/yolort-best.pt))
>>>model_script = torch.jit.script(updated_model)
>>>model_script.save(/path/to/yolort-best.torchscript.pt)

./yolo_inference --input_source /path/to/img --checkpoint /path/to/yolort-best.torchscript.pt

Is that correct?

@zhiqwang
Owner

zhiqwang commented Jul 16, 2021

I believe you're saying I'm missing a step in my pipeline and it should be like (>>> denotes added lines):

Yep! And I guess there is one minor fix in your snippets,

updated_model.model.load_state_dict(torch.load(/path/to/yolort-best.pt))

This is because we wrap the converted vanilla yolov5 model in YOLOModule here:

https://github.com/zhiqwang/yolov5-rt-stack/blob/53a9d6186380663b9a320f73e408683334da1a0d/yolort/models/yolo_module.py#L48-L49

The pre-processing (YOLOTransform, our reimplementation of yolov5's letterbox) is also done in YOLOModule:

https://github.com/zhiqwang/yolov5-rt-stack/blob/53a9d6186380663b9a320f73e408683334da1a0d/yolort/models/yolo_module.py#L51

@mattpopovich
Contributor Author

Are you saying to go from

updated_model.load_state_dict(torch.load(/path/to/yolort-best.pt))

to

updated_model.model.load_state_dict(torch.load(/path/to/yolort-best.pt)) 

???

If so, I am getting the error AttributeError: 'YOLO' object has no attribute 'model':

Adding AutoShape... 
Traceback (most recent call last):
  File "convert_ultralytics_to_rt-stack.py", line 78, in <module>
    model.model.load_state_dict(torch.load(yolort_weight_path))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 947, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'YOLO' object has no attribute 'model'

@zhiqwang
Owner

zhiqwang commented Jul 16, 2021

If so, I am getting the error AttributeError: 'YOLO' object has no attribute 'model':

Sorry @mattpopovich, try the following instead:

updated_model.model.model.load_state_dict(torch.load(/path/to/yolort-best.pt)) 

The current model is wrapped one layer deeper than I thought.

@mattpopovich
Contributor Author

'YOLO' object has no attribute 'model', so I cannot access updated_model.model, let alone updated_model.model.model.

@mattpopovich
Contributor Author

Regardless, I don't think that is the issue. The conversion succeeds when using updated_model.load_state_dict(torch.load(/path/to/yolort-best.pt)), but LibTorch C++ inference still fails, although with a different error now.
And just to make sure the error wasn't caused by my custom model, I am now starting the conversion from yolov5s-v4.0.pt and getting the same error as with my custom model:

root@pc:yolov5-rt-stack/deployment/libtorch/build# ./yolo_inference --input_source dog.jpg --checkpoint yolov5s-v4.0-yolort.torchscript.pt --labelmap coco.names 
>>> Set CPU mode
>>> Loading model
>>> Model loaded
>>> Run once on empty image
terminate called after throwing an instance of 'c10::Error'
  what():  forward() Expected a value of type 'Tensor' for argument 'samples' but instead found type 'List[Tensor]'.
Position: 1
Declaration: forward(__torch__.yolort.models.yolo.YOLO self, Tensor samples, Tensor? targets=None) -> ((Dict(str, Tensor), Dict(str, Tensor)[]))
Exception raised from checkArg at /resources/pytorch/aten/src/ATen/core/function_schema_inl.h:159 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7f9714b86b5c in /usr/local/lib/python3.8/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfa (0x7f9714b4dd4c in /usr/local/lib/python3.8/dist-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x13d87d2 (0x7f9708f137d2 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::jit::GraphFunction::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10::IValue, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, c10::IValue> > > const&) + 0x31 (0x7f970b5587a1 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::jit::Method::operator()(std::vector<c10::IValue, std::allocator<c10::IValue> >, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10::IValue, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, c10::IValue> > > const&) + 0x168 (0x7f970b568c88 in /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x3013f (0x56554b8f213f in ./yolo_inference)
frame #6: <unknown function> + 0x2492a (0x56554b8e692a in ./yolo_inference)
frame #7: __libc_start_main + 0xf3 (0x7f96c66cf0b3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #8: <unknown function> + 0x2315e (0x56554b8e515e in ./yolo_inference)

Aborted (core dumped)

@mattpopovich
Contributor Author

If you're interested in reproducing the above error, you can run the below convert_ultralytics_to_rt-stack.py, just change the file paths.

You can get yolov5s-v4.0.pt and bus.jpg from those links.

Once you get the file yolov5s-v4.0-RT.torchscript.pt, you can run it via LibTorch in C++ with:

root@pc:yolov5-rt-stack/deployment/libtorch/build# ./yolo_inference --input_source bus.jpg --checkpoint yolov5s-v4.0-RT.torchscript.pt --labelmap coco.names

You should get the Expected a value of type 'Tensor' for argument 'samples' but instead found type 'List[Tensor]' error I mentioned above.

If you think the below script is handy, I'd be happy to clean it up and add it to the repo. It is mostly a duplicate of your Jupyter notebook though!

convert_ultralytics_to_rt-stack.py:

# When given a path to a yolov5-v4.0 weights file, 
#   will convert it to a yolov5-rt-stack weights file, 
#   check it for equivalence, 
#   then convert it to torchscript for inference via LibTorch in C++
# With inspiration from: https://github.com/zhiqwang/yolov5-rt-stack/blob/master/notebooks/inference-pytorch-export-libtorch.ipynb
# Author: Matt Popovich (mattpopovich.com)

import sys 
import os 
import torch 
import cv2

sys.path.insert(0, "/path/to/yolov5-rt-stack")

from yolort.utils.image_utils import (
    letterbox,
    non_max_suppression,
    plot_one_box,
    scale_coords,
    color_list,
)
from yolort.utils import (
    cv2_imshow,
    get_image_from_url,
    read_image_to_tensor,
    update_module_state_from_ultralytics,
)

# Define static variables
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
device = torch.device('cuda')
is_half = False
conf = 0.4
iou = 0.45
img_size = 512

# Define file paths 
img_path = 'bus.jpg'
ultralytics_weights_path =  'yolov5s-v4.0.pt'
yolort_weight_path =        'yolov5s-v4.0-RT.pt' 
yolort_script_path =        'yolov5s-v4.0-RT.torchscript.pt' 
label_path = "coco.names"

# Check paths given 
if not os.path.isfile(ultralytics_weights_path):
    sys.exit("ERROR: filename specified does not exist: " + ultralytics_weights_path)

if os.path.isfile(img_path):
    img_raw = cv2.imread(img_path)
else:
    sys.exit("ERROR: filename specified for img_path does not exist: " + img_path)

with open(label_path) as f:
    num_classes = 0 
    # Count number of lines that have text in them
    for line in f:
        # If line is empty, it will still contain '\n' = len of 1
        num_classes += 1 if len(line) > 1 else 0  
print("number of lines in {} = {}".format(label_path, num_classes))

# Preprocess
img = letterbox(img_raw, new_shape=(img_size,img_size))[0]
img = read_image_to_tensor(img, is_half)
img = img.to(device)

## Load model as ultralytics and inference
model = torch.hub.load('ultralytics/yolov5', 'custom', path=ultralytics_weights_path, autoshape=False)
model = model.to(device)
model.conf = conf  # confidence threshold (0-1)
model.iou = iou  # NMS IoU threshold (0-1)
model.classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for persons, cats and dogs
model.eval()

# Get actual anchors from ultralytics model
m = model.model[-1]  # get Detect() layer
anchor_grids = m.anchor_grid.view((3, -1)).cpu().tolist()  # get anchors

with torch.no_grad():
    ultralytics_dets = model(img[None])[0]
    ultralytics_dets = non_max_suppression(ultralytics_dets, conf, iou, agnostic=True)[0]

print(f'Detection results with ultralytics:\n{ultralytics_dets}')

model = update_module_state_from_ultralytics(arch='yolov5s',
                                             version='v4.0',
                                             custom_path_or_model=ultralytics_weights_path,
                                             set_fp16=is_half,
                                             num_classes=num_classes)

# The updated model checkpoint
torch.save(model.state_dict(), yolort_weight_path)

## Load model as yolort and inference
from yolort.models.yolo import yolov5_darknet_pan_s_r40 as yolov5s

model = yolov5s(score_thresh=conf, nms_thresh=iou, num_classes=num_classes, anchor_grids=anchor_grids)
model.load_state_dict(torch.load(yolort_weight_path))
model = model.to(device)

model.eval()

with torch.no_grad():
    yolort_dets = model(img[None])

print(f"Detection boxes with yolort:\n{yolort_dets[0]['boxes']}")

print(f"Detection scores with yolort:\n{yolort_dets[0]['scores']}")

print(f"Detection labels with yolort:\n{yolort_dets[0]['labels']}")

# Testing boxes
torch.testing.assert_allclose(
    yolort_dets[0]['boxes'], ultralytics_dets[:, :4], rtol=1e-05, atol=1e-07)
# Testing scores
torch.testing.assert_allclose(
    yolort_dets[0]['scores'], ultralytics_dets[:, 4], rtol=1e-05, atol=1e-07)
# Testing labels
torch.testing.assert_allclose(
    yolort_dets[0]['labels'], ultralytics_dets[:, 5].to(dtype=torch.int64), rtol=1e-05, atol=1e-07)

print("Exported model has been tested, and the result looks good!")

## Detect output visualization
# Get label names
from yolort.utils.image_utils import load_names

LABELS = load_names(label_path)
COLORS = color_list()

# Hah, that's the trick to rescale the box correctly
boxes = scale_coords(yolort_dets[0]['boxes'], img.shape[1:], img_raw.shape[:-1])

for box, label in zip(boxes.tolist(), yolort_dets[0]['labels'].tolist()):
    img_raw = plot_one_box(box, img_raw, color=COLORS[label % len(COLORS)], label=LABELS[label])

# cv2_imshow(img_raw, imshow_scale=0.5)     # If in Jupyter notebook 
# cv2.imshow("img_raw", img_raw)            # If running script locally
# cv2.waitKey()

# Create script of yolov5-rt-stack model 
print("\nBeginning export of torchscript model...")
model_script = torch.jit.script(model)
model_script.eval()
model_script = model_script.to(device)
model_script.save(yolort_script_path)
print("Exported yolov5-rt-stack torchscript model to: ")
print("\t" + yolort_script_path)
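
A side note on the class count in the script above: it treats any line longer than one character as a label, so a whitespace-only line would be counted as a class, and a one-character name on a final line without a trailing newline would be missed. A slightly more defensive sketch, assuming one class name per line (the helper name `count_labels` is hypothetical):

```python
def count_labels(label_path):
    """Count the non-blank lines in a labelmap file (one class name per line)."""
    with open(label_path) as f:
        # line.strip() is empty for blank and whitespace-only lines.
        return sum(1 for line in f if line.strip())
```

It could replace the counting loop with `num_classes = count_labels(label_path)`.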

@zhiqwang
Owner

zhiqwang commented Jul 17, 2021

Hi @mattpopovich ,

Thank you for the feedback. I'll try to reproduce this bug and fix it ASAP. One thing I can say with high confidence is that you don't need to check out v4.0 of yolov5 and retrain your model.

@zhiqwang
Owner

zhiqwang commented Jul 18, 2021

Hi, @mattpopovich ,

Thanks for your detailed analysis and the script above. I think I found where our procedures differ; there is one thing to be careful about when calling yolort models. We wrap YOLOv5 in two places: one is YOLO and the other is YOLOModule. Only YOLOModule is used in the LibTorch inference, but the lines below from your script use YOLO:

from yolort.models.yolo import yolov5_darknet_pan_s_r40 as yolov5s  # This is `YOLO`

model = yolov5s(score_thresh=conf, nms_thresh=iou, num_classes=num_classes, anchor_grids=anchor_grids)
model.load_state_dict(torch.load(yolort_weight_path))
model = model.to(device)

These lines should be changed to

from yolort.models import yolov5s  # This is `YOLOModule`

# Make sure you are using the master branch here!
model_yolort = yolov5s(
    pretrained=False,
    num_classes=num_classes,
    anchor_grids=anchor_grids,  # same as above
    nms_thresh=0.45,
    score_thresh=0.3,
)
# Load your updated checkpoint here
model_yolort.model.load_state_dict(torch.load("yolov5s-v4.0-RT.pt"))
# jit scripting the pytorch model
model_yolort_script = torch.jit.script(model_yolort)
model_yolort_script.save("yolov5s-v4.0-RT.torchscript.pt")

Let me know if the above modification works for you!
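
For context on why load_state_dict is called on model_yolort.model rather than on model_yolort itself: PyTorch prefixes state-dict keys with the attribute name of each submodule, so a checkpoint saved from the inner model has no `model.` prefix. Here is a dependency-free sketch of that key mapping, using plain dicts and made-up key names (`add_prefix` and `strip_prefix` are hypothetical helpers, not part of yolort):

```python
def add_prefix(state_dict, prefix):
    """Return a copy of `state_dict` whose keys all start with `prefix`."""
    return {prefix + key: value for key, value in state_dict.items()}


def strip_prefix(state_dict, prefix):
    """Return a copy with `prefix` removed from every key.

    Raises KeyError if any key does not carry the prefix.
    """
    missing = [key for key in state_dict if not key.startswith(prefix)]
    if missing:
        raise KeyError(f"keys without prefix {prefix!r}: {missing}")
    return {key[len(prefix):]: value for key, value in state_dict.items()}


# Made-up keys standing in for the inner model's checkpoint.
inner = {"backbone.conv.weight": 1, "head.bias": 2}
wrapped = add_prefix(inner, "model.")
```

Under this assumption, loading `add_prefix(checkpoint, "model.")` into the wrapper should be equivalent to loading the checkpoint into its `.model` attribute, provided the wrapper has no parameters or buffers of its own.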

@zhiqwang
Owner

zhiqwang commented Jul 18, 2021

FYI, the pre-processing operation is in YOLOModule: we wrapped YOLOv5's letterbox as YOLOTransform and call it in

https://github.com/zhiqwang/yolov5-rt-stack/blob/947956f3c472bd4395272e261fe6f9d46bfd8b61/yolort/models/yolo_module.py#L51

There are some minor differences in behavior between letterbox and YOLOTransform, although their effects are the same. Ultralytics implements this with OpenCV (cv2.resize and cv2.copyMakeBorder), which torch.jit.script cannot compile, so we reimplemented it with torch.nn.functional.

This difference could be eliminated, but we do not have time to resolve it now. That's why we only test the consistency between YOLO and ultralytics/yolov5 in https://github.com/zhiqwang/yolov5-rt-stack/blob/master/notebooks/how-to-align-with-ultralytics-yolov5.ipynb .

And you can see #92 (comment) for the difference in mAP between letterbox and YOLOTransform.

BTW, the post-processing (nms) operation is wrapped in YOLO.

All contributions and suggestions are welcome here.
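
To make the letterbox idea concrete, here is a simplified, pure-Python sketch of the geometry: resize while keeping the aspect ratio, pad to a square, and map boxes back to the original image afterwards. It is an illustration only; the function names are hypothetical, and the real letterbox and YOLOTransform may differ in rounding, stride alignment, and interpolation.

```python
def letterbox_params(height, width, target=512):
    """Scale factor and left/top padding that fit an image into a square.

    Simplified sketch: no stride alignment and no limit on up-scaling.
    """
    scale = min(target / height, target / width)
    new_h, new_w = round(height * scale), round(width * scale)
    pad_left = (target - new_w) / 2
    pad_top = (target - new_h) / 2
    return scale, pad_left, pad_top


def to_letterboxed(box, scale, pad_left, pad_top):
    """Map an (x1, y1, x2, y2) box from original to letterboxed coordinates."""
    x1, y1, x2, y2 = box
    return (x1 * scale + pad_left, y1 * scale + pad_top,
            x2 * scale + pad_left, y2 * scale + pad_top)


def to_original(box, scale, pad_left, pad_top):
    """Inverse of to_letterboxed: map a box back to original coordinates."""
    x1, y1, x2, y2 = box
    return ((x1 - pad_left) / scale, (y1 - pad_top) / scale,
            (x2 - pad_left) / scale, (y2 - pad_top) / scale)
```

For a hypothetical 1080x810 (height x width) image and a 512 target, the image is scaled by 512/1080 and padded on the left and right only.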

@mattpopovich
Contributor Author

Let me know if the above modification works for you!

Great news, I ran a quick test tonight and it appears to work on the stock yolov5-v4.0 pre-trained MS COCO weights file! I will do further validation tomorrow.

Thanks for your details analysis and scripts above, I think I got where our operations are different, there is one thing we need to be careful when calling yolort models. Actually we wrap the YOLOv5 in two places, one is YOLO and the other is YOLOModule, and we just use YOLOModule in the libtorch inference, but your scripts at below are using YOLO

That was definitely the issue. I couldn't remember the import path for yolov5s, so I just grabbed a line from the first Jupyter notebook I had open. I did not realize that YOLO and YOLOModule were different!

Really appreciate your extremely speedy responses in getting this resolved! Awesome repo you have here.
