Is torch-pruning compatible with Detectron2? #421

Open
randShuffle opened this issue Sep 9, 2024 · 0 comments
Comments

@randShuffle

Hello, I trained a model with Detectron2, loaded it, and tried to prune it with torch-pruning. However, I always get the following error:
NotImplementedError Traceback (most recent call last)
d:\Phishpedia\prune\test.ipynb Cell 4 line 6
54 torch.save(model, f'./rcnn_bet365_prune30_{i}.pth') # without .state_dict
55 model = torch.load(f'./rcnn_bet365_prune30_{i}.pth') # load the pruned model
---> 63 prune_pipeline(model,iterative_steps=2)

d:\Phishpedia\prune\test.ipynb Cell 4 line 2
23 ignored_layers = []
24 # for m in model.modules():
25 # if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
26 # ignored_layers.append(m) # DO NOT prune the final classifier!
---> 28 pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
29 model,
30 example_inputs,
31 importance=imp,
32 pruning_ratio=0.1, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
33 ignored_layers=ignored_layers,
34 isomorphic=True, # enable isomorphic pruning to improve global ranking
35 global_pruning=True, # global pruning
36 forward_fn=custom_forward_fn
37 )
39 ### iterative pruning
40 for i in range(iterative_steps):
...
File d:\detectron2-windows\detectron2-windows\detectron2\structures\instances.py:149, in Instances.__iter__(self)
148 def __iter__(self):
--> 149 raise NotImplementedError("Instances object is not iterable!")

NotImplementedError: Instances object is not iterable!

My code looks like this:

# Load the trained Faster R-CNN model

import torch
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.checkpoint import DetectionCheckpointer
import argparse
import torch_pruning as tp

def config_rcnn(cfg_path, weights_path, conf_threshold):
    '''
    Configure weights and confidence threshold
    :param cfg_path:
    :param weights_path:
    :param conf_threshold:
    :return:
    '''

    cfg = get_cfg()

    cfg.merge_from_file(cfg_path)
    cfg.MODEL.WEIGHTS = weights_path

    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = conf_threshold

    # fall back to CPU when CUDA is not available (e.g. detectron2 cpu version)
    if not torch.cuda.is_available():
        cfg.MODEL.DEVICE = 'cpu'

    model = DefaultTrainer.build_model(cfg)
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

    return model

# Fine-tune the Faster R-CNN model
def finetune(model):
    pass  # todo

def prune_pipeline(model, iterative_steps):
    height, width = 800, 800

    # Detectron2 models take a list of dicts, one per image
    example_inputs = [{"image": torch.randn(3, height, width)}]

    # 1. Importance criterion
    imp = tp.importance.GroupNormImportance(p=2)  # or GroupTaylorImportance(), GroupHessianImportance(), etc.

    # 2. Initialize a pruner with the model and the importance criterion
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
            ignored_layers.append(m)  # DO NOT prune the final classifier!

    pruner = tp.pruner.MetaPruner(  # We can always choose MetaPruner if sparse training is not required.
        model,
        example_inputs,
        importance=imp,
        pruning_ratio=0.3,  # remove 30% of the channels
        ignored_layers=ignored_layers,
        isomorphic=True,  # enable isomorphic pruning to improve global ranking
        global_pruning=True,  # global pruning
    )

    ### iterative pruning
    for i in range(iterative_steps):
        for group in pruner.step(interactive=True):  # Warning: groups must be handled sequentially. Do not keep them as a list.
            print(group)
            # do whatever you like with the group
            dep, idxs = group[0]  # get the idxs
            target_module = dep.target.module  # get the root module
            pruning_fn = dep.handler  # get the pruning function
            group.prune()

        macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
        # finetune your model here
        finetune(model)
        ### save
        model.zero_grad()  # Remove gradients
        torch.save(model, './models/rcnn_bet365_prune30.pth')  # without .state_dict
        model = torch.load('./models/rcnn_bet365_prune30.pth')  # load the pruned model from the path just saved

if __name__ == '__main__':
    cfg_path = './models/faster_rcnn.yaml'
    weights_path = './models/rcnn_bet365.pth'
    conf_threshold = 0.5
    model = config_rcnn(cfg_path, weights_path, conf_threshold)
    # for test
    prune_pipeline(model, iterative_steps=1)

So my question is: can torch-pruning be used to prune a Detectron2 model?
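
For anyone hitting the same error, here is a minimal sketch of what may be going on. It runs without Detectron2 or PyTorch. It assumes (the elided part of the traceback does not confirm this) that the pruner flattens the model output into a list of tensors while tracing, and that the Detectron2 model returns a list of dicts that each hold an Instances object. `FakeInstances`, `FakeTensor`, `flatten_as_list` and `output_transform` are hypothetical stand-ins written for illustration, not the real code of either library.

```python
class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""
    def __init__(self, name):
        self.name = name


class FakeInstances:
    """Simplified stand-in for detectron2.structures.Instances (hypothetical)."""
    def __init__(self, **fields):
        self._fields = dict(fields)

    def get_fields(self):
        return self._fields

    def __iter__(self):
        # Mirrors the behaviour shown in the traceback above.
        raise NotImplementedError("Instances object is not iterable!")


def flatten_as_list(obj):
    """Assumed flattening step: walk lists, tuples and dicts, collect tensors."""
    if isinstance(obj, FakeTensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        flat = []
        for item in obj:
            flat.extend(flatten_as_list(item))
        return flat
    if isinstance(obj, dict):
        flat = []
        for value in obj.values():
            flat.extend(flatten_as_list(value))
        return flat
    # Unknown type: returned as-is, so the caller's extend() tries to iterate it.
    return obj


def output_transform(outputs):
    """Hypothetical transform: replace each Instances with a list of its fields.

    With real Detectron2 outputs some fields are wrappers rather than tensors
    (pred_boxes is a Boxes object), so they would need unwrapping as well.
    """
    return [list(out["instances"].get_fields().values()) for out in outputs]


outputs = [{"instances": FakeInstances(scores=FakeTensor("scores"),
                                       pred_classes=FakeTensor("pred_classes"))}]

# Flattening the raw output ends up iterating the Instances object and fails.
try:
    flatten_as_list(outputs)
    failed = False
    error_message = ""
except NotImplementedError as exc:
    failed = True
    error_message = str(exc)

# Flattening the transformed output only ever sees lists and tensors.
tensor_names = [t.name for t in flatten_as_list(output_transform(outputs))]
print(failed, error_message, tensor_names)
```

If this reading is right, reducing the real model output to plain tensors before it is flattened (for example through an `output_transform`-style hook, if the installed torch-pruning version exposes one) might get past this particular error. Whether the rest of a Detectron2 model can then be traced and pruned is a separate question that this sketch does not answer.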
