resnet18.py
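"""LegoDNN block-level compression pipeline for ResNet-18 on CIFAR-100 / CIFAR-10.

Summary of the steps implemented below:
  1. Load the teacher model and build its topology graph (topology_extraction).
  2. Detect compressible blocks from the graph (CommonDetectionManager).
  3. Extract each block at several sparsity levels (BlockExtractor).
  4. Train the extracted blocks (BlockTrainer).
  5. Profile the blocks for memory/accuracy (ServerBlockProfiler) and latency (EdgeBlockProfiler).
  6. Generate a series of descendant models via OptimalRuntime and gen_series_legodnn_models.
"""
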
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from legodnn.utils.dl.common.env import set_random_seed
set_random_seed(0)
import sys
sys.setrecursionlimit(100000)  # maximum recursion depth (related to resource usage)
import torch
from legodnn import BlockExtractor, BlockTrainer, ServerBlockProfiler, EdgeBlockProfiler, OptimalRuntime
from legodnn.gen_series_legodnn_models import gen_series_legodnn_models
from legodnn.block_detection.model_topology_extraction import topology_extraction  # builds the model topology graph
from legodnn.presets.auto_block_manager import AutoBlockManager
from legodnn.presets.common_detection_manager_1204_new import CommonDetectionManager
from legodnn.model_manager.common_model_manager import CommonModelManager  # implementation of AbstractModelManager
from legodnn.utils.common.file import experiments_model_file_path
from legodnn.utils.dl.common.model import get_module, set_module, get_model_size
from cv_task.datasets.image_classification.cifar_dataloader import CIFAR10Dataloader, CIFAR100Dataloader
from cv_task.image_classification.cifar.models import resnet18
from cv_task.image_classification.cifar.legodnn_configs import get_cifar100_train_config_200e
from cv_task.log_on_time.time_logger import time_logger

dataset_root_dir = '/home/marcus/newspace/datasets'
checkpoint_path = '/home/marcus/newspace/LegoDNN_teacher_model/cifar100/resnet18/2023-06-28/13-03-18/resnet18.pth'
num_workers = 4  # default = 4; set to 0 if the dataloader raises an error

if __name__ == '__main__':
    cv_task = 'image_classification'
    dataset_name = 'cifar100'
    model_name = 'resnet18'
    method = 'legodnn'
    device = 'cuda'
    compress_layer_max_ratio = 0.125  # maximum per-layer compression ratio
    model_input_size = (1, 3, 32, 32)
    block_sparsity = [0.0, 0.2, 0.4, 0.6, 0.8]
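    # Each detected block will be extracted and trained at every sparsity level
    # listed in block_sparsity above (0.0 presumably keeps the block uncompressed).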
    root_path = os.path.join('results/legodnn', cv_task, model_name + '_' + dataset_name + '_' + str(compress_layer_max_ratio).replace('.', '-'))
    time_logger_obj = time_logger('../../time_log/legodnn_execute_time', model_name + '_' + dataset_name)
    time_logger_obj.start()

    compressed_blocks_dir_path = root_path + '/compressed'   # compressed blocks
    trained_blocks_dir_path = root_path + '/trained'         # trained blocks
    descendant_models_dir_path = root_path + '/descendant'   # descendant models
    block_training_max_epoch = 65
    test_sample_num = 100
    checkpoint = checkpoint_path
    if dataset_name == 'cifar100':
        teacher_model = resnet18(num_classes=100).to(device)
    elif dataset_name == 'cifar10':
        teacher_model = resnet18(num_classes=10).to(device)
    else:
        print('\033[31mWrong Dataset!!!\033[0m')
        sys.exit(1)  # stop here, otherwise teacher_model would be undefined below
    teacher_model.load_state_dict(torch.load(checkpoint)['net'])  # load weights (not available for the initial training run)

    print('\033[1;36m--------------------------------> BUILD LEGODNN GRAPH\033[0m')  # build the topology graph
    model_graph = topology_extraction(teacher_model, model_input_size, device=device, mode='unpack')  # mode: 'pack' or 'unpack'
    model_graph.print_ordered_node()  # print nodes in order
    time_logger_obj.lap('Build topological graph')

    print('\033[1;36m--------------------------------> START BLOCK DETECTION\033[0m')  # detect blocks from the topology graph
    detection_manager = CommonDetectionManager(model_graph, max_ratio=compress_layer_max_ratio)
    detection_manager.detection_all_blocks()
    detection_manager.print_all_blocks()
    time_logger_obj.lap('Detect blocks')

    # model manager and block manager
    model_manager = CommonModelManager()
    block_manager = AutoBlockManager(block_sparsity, detection_manager, model_manager)
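    # CommonModelManager is the AbstractModelManager implementation used by the stages
    # below; AutoBlockManager ties the detected blocks to the sparsity levels above.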

    print('\033[1;36m--------------------------------> START BLOCK EXTRACTION\033[0m')  # block extraction
    block_extractor = BlockExtractor(teacher_model, block_manager, compressed_blocks_dir_path, model_input_size, device)
    block_extractor.extract_all_blocks()  # extract blocks at each sparsity level
    time_logger_obj.lap('Extract all blocks')

    print('\033[1;36m--------------------------------> START BLOCK TRAIN\033[0m')
    # num_workers >= 1 may raise "ValueError: signal number 32 out of range"
    if dataset_name == 'cifar100':
        train_loader, test_loader = CIFAR100Dataloader(root_dir=dataset_root_dir, num_workers=num_workers)
    else:  # cifar10
        train_loader, test_loader = CIFAR10Dataloader(root_dir=dataset_root_dir, num_workers=num_workers)
    print("\033[32mDataloader done\033[0m")
    block_trainer = BlockTrainer(teacher_model, block_manager, model_manager, compressed_blocks_dir_path,
                                 trained_blocks_dir_path, block_training_max_epoch, train_loader, device=device)
    print("\033[32mBlock trainer initialized\033[0m")
    block_trainer.train_all_blocks()
    print("\033[32mBlocks trained\033[0m")
    time_logger_obj.lap('Train all blocks')

    # memory and accuracy profiler (original blocks & compressed blocks)
    server_block_profiler = ServerBlockProfiler(teacher_model, block_manager, model_manager,
                                                trained_blocks_dir_path, test_loader, model_input_size, device)
    server_block_profiler.profile_all_blocks()
    time_logger_obj.lap('Profile blocks on memory & acc')

    # latency profiler (original blocks & compressed blocks)
    edge_block_profiler = EdgeBlockProfiler(block_manager, model_manager, trained_blocks_dir_path,
                                            test_sample_num, model_input_size, device)
    edge_block_profiler.profile_all_blocks()
    time_logger_obj.lap('Profile blocks on latency')
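
    # Final stage: OptimalRuntime consumes the block profiles, and gen_series_legodnn_models
    # searches a model-size range (in MB) from the bare model frame (presumably the model
    # with the detected blocks stripped out) up to roughly the full teacher model size,
    # generating descendant models for the given deadline.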
    optimal_runtime = OptimalRuntime(trained_blocks_dir_path, model_input_size,
                                     block_manager, model_manager, device)
    model_size_min = get_model_size(torch.load(os.path.join(compressed_blocks_dir_path, 'model_frame.pt'))) / 1024**2
    model_size_max = get_model_size(teacher_model) / 1024**2 + 1
    gen_series_legodnn_models(deadline=100, model_size_search_range=[model_size_min, model_size_max],
                              target_model_num=100, optimal_runtime=optimal_runtime,
                              descendant_models_save_path=descendant_models_dir_path, device=device)
    time_logger_obj.end('Optimal runtime search')