-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_CCC_gat.py
80 lines (67 loc) · 3.36 KB
/
run_CCC_gat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import sys
import numpy as np
from sklearn import metrics
from scipy import sparse
import pickle
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, ChebConv, GATConv, DeepGraphInfomax, global_mean_pool, global_max_pool # noqa
from torch_geometric.data import Data, DataLoader
from datetime import datetime
import time
# Switch the working directory to "<parent>/CCC_project", where <parent> is the
# directory above sys.path[0] (for a script run directly, sys.path[0] is the
# directory containing that script). Relative paths such as the default
# '--data_path generated_data/' are therefore resolved against CCC_project.
# os.path.join is used instead of "+ '/CCC_project'" so that an empty rootPath
# (sys.path[0] == '') yields the relative 'CCC_project' rather than the
# filesystem-root path '/CCC_project', and so the separator is platform-correct.
rootPath = os.path.dirname(sys.path[0])
os.chdir(os.path.join(rootPath, 'CCC_project'))
import random
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
# =========================== args ===============================
parser.add_argument( '--data_name', type=str, default='V1_Breast_Cancer_Block_A_Section_1', help="'MERFISH' or 'V1_Breast_Cancer_Block_A_Section_1")
parser.add_argument( '--data_path', type=str, default='generated_data/', help='data path')
parser.add_argument( '--model_path', type=str, default='model')
parser.add_argument( '--embedding_data_path', type=str, default='Embedding_data')
#parser.add_argument( '--result_path', type=str, default='results')
parser.add_argument( '--load', type=int, default=0, help='Load pretrained DGI model')
parser.add_argument( '--num_epoch', type=int, default=5000, help='numebr of epoch in training DGI')
parser.add_argument( '--hidden', type=int, default=256, help='hidden channels in DGI')
parser.add_argument( '--retrain', type=int, default=0 , help='Run which clustering at the end')
parser.add_argument( '--model_load_path', type=str, default='model')
parser.add_argument( '--model_name', type=str, default='r1')
parser.add_argument( '--training_data', type=str, default='provide please')
parser.add_argument( '--heads', type=int, default=1)
parser.add_argument( '--num_cells', type=int, default=1)
parser.add_argument( '--options', type=str)
parser.add_argument( '--withFeature', type=str, default='r1')
parser.add_argument( '--workflow_v', type=int)
parser.add_argument( '--datatype', type=str)
parser.add_argument( '--dropout', type=float, default=0)
parser.add_argument( '--lr_rate', type=float, default=0.00001)
parser.add_argument( '--manual_seed', type=str, default='no')
parser.add_argument( '--seed', type=int)
args = parser.parse_args()
args.embedding_data_path = args.embedding_data_path + args.data_name +'/'
args.model_path = args.model_path + args.data_name +'/'
#args.result_path = args.result_path +'/'+ args.data_name +'/'
args.model_load_path = args.model_load_path + args.data_name +'/'
print(args.model_name+', '+str(args.heads)+', '+args.training_data+', '+str(args.hidden) )
if args.manual_seed == 'yes':
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
start_time = time.time()
if not os.path.exists(args.embedding_data_path):
os.makedirs(args.embedding_data_path)
if not os.path.exists(args.model_path):
os.makedirs(args.model_path)
#args.result_path = args.result_path+'/'
#if not os.path.exists(args.result_path):
# os.makedirs(args.result_path)
print ('------------------------Model and Training Details--------------------------')
print(args)
from CCC_train_ST import CCC_on_ST
CCC_on_ST(args)
end_time = time.time() - start_time
print('time elapsed %g min'%(end_time/60))