Skip to content

Commit

Permalink
fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
songtianhui committed Mar 30, 2023
1 parent ea4a311 commit 5532478
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lib/models/mixformer_convmae/mixformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def get_mixformer_convmae(config, train):

if config.MODEL.BACKBONE.PRETRAINED and train:
ckpt_path = config.MODEL.BACKBONE.PRETRAINED_PATH
ckpt = torch.load(ckpt_path, map_location='cpu') #['model']
ckpt = torch.load(ckpt_path, map_location='cpu')['model']
new_dict = {}
for k, v in ckpt.items():
if 'pos_embed' not in k and 'mask_token' not in k:
Expand Down
2 changes: 1 addition & 1 deletion lib/models/mixformer_convmae/mixformer_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def get_mixformer_convmae(config, train):

if config.MODEL.BACKBONE.PRETRAINED and train:
ckpt_path = config.MODEL.BACKBONE.PRETRAINED_PATH
ckpt = torch.load(ckpt_path, map_location='cpu') #['model']
ckpt = torch.load(ckpt_path, map_location='cpu')['model']
new_dict = {}
for k, v in ckpt.items():
if 'pos_embed' not in k and 'mask_token' not in k:
Expand Down
2 changes: 1 addition & 1 deletion lib/test/tracker/mixformer_convmae.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class MixFormer(BaseTracker):
def __init__(self, params, dataset_name):
super(MixFormer, self).__init__(params)
network = build_mixformer_convmae(params.cfg)
network = build_mixformer_convmae(params.cfg, train=False)
network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True)
self.cfg = params.cfg
self.network = network.cuda()
Expand Down
2 changes: 1 addition & 1 deletion lib/test/tracker/mixformer_cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class MixFormer(BaseTracker):
def __init__(self, params, dataset_name):
super(MixFormer, self).__init__(params)
network = build_mixformer_cvt(params.cfg)
network = build_mixformer_cvt(params.cfg, train=False)
network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True)
self.cfg = params.cfg
self.network = network.cuda()
Expand Down
2 changes: 1 addition & 1 deletion lib/test/tracker/mixformer_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class MixFormer(BaseTracker):
def __init__(self, params, dataset_name):
super(MixFormer, self).__init__(params)
network = build_mixformer_vit(params.cfg)
network = build_mixformer_vit(params.cfg, train=False)
network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True)
self.cfg = params.cfg
self.network = network.cuda()
Expand Down

0 comments on commit 5532478

Please sign in to comment.