diff --git a/notebooks/oligodendrocyte/CohAE_oligo.pkl b/notebooks/oligodendrocyte/CohAE_oligo.pkl new file mode 100644 index 0000000..a244954 Binary files /dev/null and b/notebooks/oligodendrocyte/CohAE_oligo.pkl differ diff --git a/notebooks/oligodendrocyte/model-oligodendrocyte-gat.ipynb b/notebooks/oligodendrocyte/model-oligodendrocyte-gat.ipynb new file mode 100644 index 0000000..203cf20 --- /dev/null +++ b/notebooks/oligodendrocyte/model-oligodendrocyte-gat.ipynb @@ -0,0 +1,474 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "import numpy as np\n", + "import scvelo as scv\n", + "import scanpy\n", + "import torch\n", + "\n", + "from veloproj import *\n", + "\n", + "scv.settings.verbosity = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(adata=None, aux_weight=100.0, conv_thred=1e-06, data_dir='./oligo_lite.h5ad', device='cuda:2', exp_name='CohAE_oligo', fit_offset_pred=True, fit_offset_train=False, g_rep_dim=50, gnn_layer='GAT', gumbsoft_tau=5.0, h_dim=256, is_half=False, k_dim=50, ld_adata='projection.h5', ld_nb_g_src='SU', log_interval=100, lr=1e-05, lr_decay=0.9, mask_cluster_list='velo_constraint_cluster_list.txt', model_name='oligo_model.cpt', n_conn_nb=30, n_epochs=20000, n_nb_newadata=30, n_raw_gene=2000, nb_g_src='X', output='./', refit=True, scv_n_jobs=10, seed=42, sl1_beta=8.0, use_norm=False, use_offset_pred=False, use_x=False, v_rg_wt=500.0, vis_key='X_umap', vis_type_col='celltype', weight_decay=0.0, z_dim=100)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "parser = get_parser()\n", + "args = parser.parse_args(args=['--lr', '1e-5',\n", + " '--n-epochs', '20000', \n", + " '--g-rep-dim', '50',\n", + " '--k-dim', '50',\n", + " '--data-dir', './oligo_lite.h5ad',\n", + " '--model-name', 
'oligo_model.cpt',\n", + " '--exp-name', 'CohAE_oligo',\n", + " '--device', 'cuda:2',\n", + " '--gumbsoft_tau', '5',\n", + " '--nb_g_src', 'X',\n", + " '--ld_nb_g_src', \"SU\",\n", + " '--n_raw_gene', '2000',\n", + " '--n_conn_nb', '30',\n", + " '--n_nb_newadata', '30',\n", + " '--aux_weight', '100',\n", + " '--fit_offset_train', 'false',\n", + " '--fit_offset_pred', 'true',\n", + " '--use_offset_pred', 'false',\n", + " '--gnn_layer', 'GAT',\n", + " '--vis-key', 'X_umap',\n", + " '--vis_type_col', 'celltype',\n", + " '--scv_n_jobs', '10',\n", + " '--mask_cluster_list', 'velo_constraint_cluster_list.txt', # in contrast to the retina dataset, where velocity constraints are added to the\n", + " # low-dimensional space for all the cells, \n", + " # we provide an interface for users to specify the specific cell clusters\n", + " # which we would like to constrain.\n", + " # this argument specifies the path to a file whose first row gives\n", + " # the cluster column name in the adata.obs DataFrame\n", + " # and whose remaining rows list the cluster types to constrain.\n", + " # in this example, we constrain the velocity of `NFOLs`\n", + " # to be similar between the projected and the linear-regression-estimated low-dimensional space,\n", + " # doing a kind of `transfer learning` from the scvelo stochastic mode to veloAE\n", + " '--v_rg_wt', '500',\n", + " '--sl1_beta', '8'\n", + " ])\n", + "args " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(args.seed)\n", + "torch.cuda.manual_seed(args.seed)\n", + "np.random.seed(args.seed)\n", + "torch.backends.cudnn.deterministic = True\n", + "\n", + "device = torch.device(args.device if args.device.startswith('cuda') and torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "cluster_edges = [\n", + " (\"COPs\", \"NFOLs\"), \n", + " (\"NFOLs\", \"MFOLs\")]\n", + "EXP_NAME = 
args.exp_name\n", + "exp_metrics = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def main_AE(args, adata):\n", + " spliced = adata.layers['Ms']\n", + " unspliced = adata.layers['Mu']\n", + " tensor_s = torch.FloatTensor(spliced).to(device)\n", + " tensor_u = torch.FloatTensor(unspliced).to(device)\n", + " tensor_x = torch.FloatTensor(adata.X.toarray()).to(device)\n", + " tensor_v = torch.FloatTensor(adata.layers['stc_velocity']).to(device)\n", + "\n", + " model = init_model(adata, args, device)\n", + "\n", + " inputs = [tensor_s, tensor_u]\n", + " xyids = [0, 1]\n", + " if args.use_x:\n", + " inputs.append(tensor_x)\n", + "\n", + " model = fit_model(args, adata, model, inputs, tensor_v, xyids, device)\n", + " return tensor_s, tensor_u, tensor_x " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Abundance of ['spliced', 'unspliced']: [0.64 0.36]\n", + "AnnData object with n_obs × n_vars = 6253 × 2000\n", + " obs: 'Age', 'Cell_Conc', 'ChipID', 'Clusters', 'Comments', 'Date_Captured', 'DonorID', 'Label', 'NGI_PlateWell', 'Num_Pooled_Animals', 'PCR_Cycles', 'Plug_Date', 'Project', 'SampleID', 'SampleOK', 'Sample_Index', 'Seq_Comment', 'Seq_Lib_Date', 'Seq_Lib_Ok', 'Serial_Number', 'Sex', 'Species', 'Strain', 'Target_Num_Cells', 'Tissue', 'Transcriptome', '_X', '_Y', 'cDNA_Lib_Ok', 'ngperul_cDNA', 'V1', 'V2', 'label', 'celltype', 'initial_size_spliced', 'initial_size_unspliced', 'initial_size', 'n_counts'\n", + " var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'\n", + " uns: 'pca', 'neighbors'\n", + " obsm: 'X_coor', 'X_umap', 'X_pca'\n", + " varm: 'PCs'\n", + " layers: 'ambiguous', 'matrix', 'spliced', 'unspliced', 'Ms', 'Mu'\n", + " obsp: 'distances', 'connectivities'\n" + ] + } + ], + "source": [ + 
"adata = scanpy.read_h5ad(args.data_dir)\n", + "scv.pp.neighbors(adata, n_neighbors=30, n_pcs=30)\n", + "scv.utils.show_proportions(adata)\n", + "scv.pp.filter_and_normalize(adata, min_shared_counts=30, n_top_genes=args.n_raw_gene)\n", + "scv.pp.moments(adata, n_pcs=30, n_neighbors=30)\n", + "print(adata)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d4e104312ec0496582218eb3722b52f0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/6253 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# Cross-Boundary Transition Score (A->B)\n", + "{('COPs', 'NFOLs'): 0.13601112, ('NFOLs', 'MFOLs'): 0.2676039}\n", + "Total Mean: 0.20180751383304596\n", + "# Cross-Boundary Velocity Coherence (A->B)\n", + "{('COPs', 'NFOLs'): 0.86201894, ('NFOLs', 'MFOLs'): 0.8960272}\n", + "Total Mean: 0.8790230751037598\n", + "# Cross-Boundary Direction Correctness (A->B)\n", + "{('COPs', 'NFOLs'): 0.18456332196112613, ('NFOLs', 'MFOLs'): 0.527163266529319}\n", + "Total Mean: 0.35586329424522256\n", + "# In-cluster Coherence\n", + "{'COPs': 0.89906085, 'MFOLs': 0.9619137, 'NFOLs': 0.9057924, 'OPCs': 0.8825826}\n", + "Total Mean: 0.9123374223709106\n", + "# In-cluster Confidence\n", + "{'COPs': 0.8975995237356137, 'MFOLs': 0.9524482225827059, 'NFOLs': 0.8943025764184453, 'OPCs': 0.8614016425225043}\n", + "Total Mean: 0.9014379913148174\n" + ] + } + ], + "source": [ + "scv.tl.velocity(adata, vkey='stc_velocity', mode=\"stochastic\")\n", + "scv.tl.velocity_graph(adata, vkey='stc_velocity', n_jobs=args.scv_n_jobs)\n", + "scv.tl.velocity_confidence(adata, vkey='stc_velocity')\n", + "scv.pl.velocity_embedding_stream(adata, vkey=\"stc_velocity\", basis=args.vis_key, color=args.vis_type_col,\n", + " dpi=150, \n", + " title='ScVelo Stochastic Mode')\n", + 
"exp_metrics[\"stc_mode\"] = evaluate(adata, cluster_edges, args.vis_type_col, \"stc_velocity\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loss: (Total) 78.359283, (AE) 77.568756, (LR) 100.00 * 0.000732, (RG) 500.00 * 0.001435: 100%|██████████| 20000/20000 [23:42<00:00, 14.06it/s] \n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "tensor_s, tensor_u, tensor_x = main_AE(args, adata)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8950947162e449fb9d83a42c7f72991b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/6253 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# Cross-Boundary Transition Score (A->B)\n", + "{('COPs', 'NFOLs'): 0.2392812, ('NFOLs', 'MFOLs'): 0.104577646}\n", + "Total Mean: 0.17192941904067993\n", + "# Cross-Boundary Velocity Coherence (A->B)\n", + "{('COPs', 'NFOLs'): 0.9953603, ('NFOLs', 'MFOLs'): 0.996584}\n", + "Total Mean: 0.9959721565246582\n", + "# Cross-Boundary Direction Correctness (A->B)\n", + "{('COPs', 'NFOLs'): 0.8487240633966037, ('NFOLs', 'MFOLs'): 0.5442262991169514}\n", + "Total Mean: 0.6964751812567775\n", + "# In-cluster Coherence\n", + "{'COPs': 0.9963123, 'MFOLs': 0.99721044, 'NFOLs': 0.99829817, 'OPCs': 0.99997526}\n", + "Total Mean: 0.9979490041732788\n", + "# In-cluster Confidence\n", + "{'COPs': 0.996084103778113, 'MFOLs': 0.9970275277809268, 'NFOLs': 0.9978537890824896, 'OPCs': 0.9856352421545214}\n", + "Total Mean: 0.9941501656990126\n" + ] + } + ], + "source": [ + "def exp(adata, exp_metrics):\n", + " model = init_model(adata, args, device)\n", + " model.load_state_dict(torch.load(args.model_name))\n", + " model = model.to(device)\n", + " model.eval()\n", + " with torch.no_grad():\n", + " x = model.encoder(tensor_x)\n", + " s = model.encoder(tensor_s)\n", + " u = model.encoder(tensor_u)\n", + " \n", + " v = estimate_ld_velocity(s, u, device=device, perc=[5, 95], \n", + " norm=args.use_norm, fit_offset=args.fit_offset_pred, \n", + " use_offset=args.use_offset_pred).cpu().numpy()\n", + " x 
= x.cpu().numpy()\n", + " s = s.cpu().numpy()\n", + " u = u.cpu().numpy()\n", + " \n", + " adata = new_adata(adata, x, s, u, v, g_basis=args.ld_nb_g_src, n_nb_newadata=args.n_nb_newadata)\n", + " scv.tl.velocity_graph(adata, vkey='new_velocity', n_jobs=args.scv_n_jobs)\n", + " scv.pl.velocity_embedding_stream(adata, vkey=\"new_velocity\", basis=args.vis_key, color=args.vis_type_col,\n", + " title=\"Project Original Velocity into Low-Dim Space\",\n", + " dpi=150,\n", + " save='oligodendrocyte_pojection.png') \n", + " scv.tl.velocity_confidence(adata, vkey='new_velocity')\n", + " exp_metrics['Cohort AutoEncoder'] = evaluate(adata, cluster_edges, args.vis_type_col, \"new_velocity\")\n", + " \n", + "exp(adata, exp_metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Abundance of ['spliced', 'unspliced']: [0.64 0.36]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b365cfd20d234ea0bb23482306bf2916", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/757 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# Cross-Boundary Transition Score (A->B)\n", + "{('COPs', 'NFOLs'): 0.39235646, ('NFOLs', 'MFOLs'): 0.43137965}\n", + "Total Mean: 0.41186803579330444\n", + "# Cross-Boundary Velocity Coherence (A->B)\n", + "{('COPs', 'NFOLs'): 0.8339525880108608, ('NFOLs', 'MFOLs'): 0.7648149955134319}\n", + "Total Mean: 0.7993837917621464\n", + "# Cross-Boundary Direction Correctness (A->B)\n", + "{('COPs', 'NFOLs'): -0.045721793801200644, ('NFOLs', 'MFOLs'): 0.09865751437961526}\n", + "Total Mean: 0.026467860289207307\n", + "# In-cluster Coherence\n", + "{'COPs': 0.9741825131895359, 'MFOLs': 0.8120565060783369, 'NFOLs': 0.7916376122571769, 'OPCs': 0.9990290240941994}\n", + "Total Mean: 0.8942264139048123\n", + "# 
In-cluster Confidence\n", + "{'COPs': 0.9734623313499147, 'MFOLs': 0.8024698073010618, 'NFOLs': 0.7887305956486915, 'OPCs': 0.9938755845841556}\n", + "Total Mean: 0.8896345797209559\n" + ] + } + ], + "source": [ + "adata = scanpy.read_h5ad(args.data_dir)\n", + "scv.utils.show_proportions(adata)\n", + "scv.pp.filter_and_normalize(adata, min_shared_counts=30, n_top_genes=args.n_raw_gene)\n", + "scv.pp.neighbors(adata, n_pcs=30, n_neighbors=30)\n", + "scv.pp.moments(adata, n_pcs=30, n_neighbors=30)\n", + "\n", + "scv.tl.recover_dynamics(adata, n_jobs=args.scv_n_jobs)\n", + "scv.tl.velocity(adata, vkey='dyn_velocity', mode=\"dynamical\")\n", + "\n", + "scv.tl.velocity_graph(adata, vkey='dyn_velocity', n_jobs=args.scv_n_jobs)\n", + "scv.tl.velocity_confidence(adata, vkey='dyn_velocity')\n", + "scv.pl.velocity_embedding_stream(adata, \n", + " vkey=\"dyn_velocity\", \n", + " basis=args.vis_key, \n", + " color=[args.vis_type_col],\n", + " dpi=150, \n", + " title='ScVelo Dynamical Mode')\n", + "exp_metrics[\"dyn_mode\"] = evaluate(adata[:, adata.var.dyn_velocity_genes], cluster_edges, args.vis_type_col, \"dyn_velocity\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"{}.pkl\".format(EXP_NAME), 'wb') as out_file:\n", + " pickle.dump(exp_metrics, out_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "a8471068523155614cf4eb871cd7d17435a8456aa34f037779c06d1e28b048ac" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git 
a/notebooks/oligodendrocyte/oligo_model.cpt b/notebooks/oligodendrocyte/oligo_model.cpt new file mode 100644 index 0000000..545349e Binary files /dev/null and b/notebooks/oligodendrocyte/oligo_model.cpt differ diff --git a/notebooks/oligodendrocyte/velo_constraint_cluster_list.txt b/notebooks/oligodendrocyte/velo_constraint_cluster_list.txt new file mode 100644 index 0000000..7816e82 --- /dev/null +++ b/notebooks/oligodendrocyte/velo_constraint_cluster_list.txt @@ -0,0 +1,2 @@ +celltype +NFOLs \ No newline at end of file