You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hello,when I train and load a model with Detectron2,I try to use torch-pruning to do model pruning.However,there always seems to be an error.
NotImplementedError Traceback (most recent call last)
d:\Phishpedia\prune\test.ipynb 单元格 4 line 6
54 torch.save(model, f'./rcnn_bet365_prune30_{i}.pth') # without .state_dict
55 model = torch.load(f'./rcnn_bet365_prune30_{i}.pth') # load the pruned model
---> 63 prune_pipeline(model,iterative_steps=2)
d:\Phishpedia\prune\test.ipynb 单元格 4 line 2
23 ignored_layers = []
24 # for m in model.modules():
25 # if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
26 # ignored_layers.append(m) # DO NOT prune the final classifier!
---> 28 pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
29 model,
30 example_inputs,
31 importance=imp,
32 pruning_ratio=0.1, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
33 ignored_layers=ignored_layers,
34 isomorphic=True, # enable isomorphic pruning to improve global ranking
35 global_pruning=True, # global pruning
36 forward_fn=custom_forward_fn
37 )
39 ### iterative pruning
40 for i in range(iterative_steps):
...
File d:\detectron2-windows\detectron2-windows\detectron2\structures\instances.py:149, in Instances.iter(self)
148 def iter(self):
--> 149 raise NotImplementedError("Instances object is not iterable!")
NotImplementedError: Instances object is not iterable!
And my code is like:
加载训练好的faster-rcnn模型
import torch
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.checkpoint import DetectionCheckpointer
import argparse
import torch_pruning as tp
cfg = get_cfg()
cfg.merge_from_file(cfg_path)
cfg.MODEL.WEIGHTS = weights_path
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = conf_threshold
# uncomment if you installed detectron2 cpu version
if not torch.cuda.is_available():
cfg.MODEL.DEVICE = 'cpu'
model = DefaultTrainer.build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
return model
example_inputs = [{"image": torch.randn(3, height, width)}]
# 1. Importance criterion
imp = tp.importance.GroupNormImportance(p=2) # or GroupTaylorImportance(), GroupHessianImportance(), etc.
# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
ignored_layers.append(m) # DO NOT prune the final classifier!
pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
model,
example_inputs,
importance=imp,
pruning_ratio=0.3, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
ignored_layers=ignored_layers,
isomorphic=True, # enable isomorphic pruning to improve global ranking
global_pruning=True, # global pruning
)
### iterative pruning
for i in range(iterative_steps):
for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
print(group)
# do whatever you like with the group
dep, idxs = group[0] # get the idxs
target_module = dep.target.module # get the root module
pruning_fn = dep.handler # get the pruning function
group.prune()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
# finetune your model here
finetune(model)
### save
model.zero_grad() # Remove gradients
torch.save(model, './models/rcnn_bet365_prune30.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model
if name == 'main':
cfg_path='./models/faster_rcnn.yaml'
weights_path='./models/rcnn_bet365.pth'
conf_threshold=0.5
model=config_rcnn(cfg_path, weights_path, conf_threshold)
# for test
prune_pipeline(model,iterative_steps=1)
So,my question is : can we use torch-pruning to prune a Detectron2 model?
The text was updated successfully, but these errors were encountered:
Hello,when I train and load a model with Detectron2,I try to use torch-pruning to do model pruning.However,there always seems to be an error.
NotImplementedError Traceback (most recent call last)
d:\Phishpedia\prune\test.ipynb 单元格 4 line 6
54 torch.save(model, f'./rcnn_bet365_prune30_{i}.pth') # without .state_dict
55 model = torch.load(f'./rcnn_bet365_prune30_{i}.pth') # load the pruned model
---> 63 prune_pipeline(model,iterative_steps=2)
d:\Phishpedia\prune\test.ipynb 单元格 4 line 2
23 ignored_layers = []
24 # for m in model.modules():
25 # if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
26 # ignored_layers.append(m) # DO NOT prune the final classifier!
---> 28 pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
29 model,
30 example_inputs,
31 importance=imp,
32 pruning_ratio=0.1, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
33 ignored_layers=ignored_layers,
34 isomorphic=True, # enable isomorphic pruning to improve global ranking
35 global_pruning=True, # global pruning
36 forward_fn=custom_forward_fn
37 )
39 ### iterative pruning
40 for i in range(iterative_steps):
...
File d:\detectron2-windows\detectron2-windows\detectron2\structures\instances.py:149, in Instances.iter(self)
148 def iter(self):
--> 149 raise NotImplementedError("
Instances
object is not iterable!")NotImplementedError:
Instances
object is not iterable!And my code is like:
加载训练好的faster-rcnn模型
import torch
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.checkpoint import DetectionCheckpointer
import argparse
import torch_pruning as tp
def config_rcnn(cfg_path, weights_path, conf_threshold):
'''
Configure weights and confidence threshold
:param cfg_path:
:param weights_path:
:param conf_threshold:
:return:
'''
#微调faster-rcnn模型
def finetune(model):
pass # todo
def prune_pipeline(model,iterative_steps):
height, width = 800, 800
if name == 'main':
cfg_path='./models/faster_rcnn.yaml'
weights_path='./models/rcnn_bet365.pth'
conf_threshold=0.5
model=config_rcnn(cfg_path, weights_path, conf_threshold)
# for test
prune_pipeline(model,iterative_steps=1)
So,my question is : can we use torch-pruning to prune a Detectron2 model?
The text was updated successfully, but these errors were encountered: