diff --git a/notebooks/oligodendrocyte/CohAE_oligo.pkl b/notebooks/oligodendrocyte/CohAE_oligo.pkl new file mode 100644 index 0000000..a244954 Binary files /dev/null and b/notebooks/oligodendrocyte/CohAE_oligo.pkl differ diff --git a/notebooks/oligodendrocyte/model-oligodendrocyte-gat.ipynb b/notebooks/oligodendrocyte/model-oligodendrocyte-gat.ipynb new file mode 100644 index 0000000..203cf20 --- /dev/null +++ b/notebooks/oligodendrocyte/model-oligodendrocyte-gat.ipynb @@ -0,0 +1,474 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "import numpy as np\n", + "import scvelo as scv\n", + "import scanpy\n", + "import torch\n", + "\n", + "from veloproj import *\n", + "\n", + "scv.settings.verbosity = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(adata=None, aux_weight=100.0, conv_thred=1e-06, data_dir='./oligo_lite.h5ad', device='cuda:2', exp_name='CohAE_oligo', fit_offset_pred=True, fit_offset_train=False, g_rep_dim=50, gnn_layer='GAT', gumbsoft_tau=5.0, h_dim=256, is_half=False, k_dim=50, ld_adata='projection.h5', ld_nb_g_src='SU', log_interval=100, lr=1e-05, lr_decay=0.9, mask_cluster_list='velo_constraint_cluster_list.txt', model_name='oligo_model.cpt', n_conn_nb=30, n_epochs=20000, n_nb_newadata=30, n_raw_gene=2000, nb_g_src='X', output='./', refit=True, scv_n_jobs=10, seed=42, sl1_beta=8.0, use_norm=False, use_offset_pred=False, use_x=False, v_rg_wt=500.0, vis_key='X_umap', vis_type_col='celltype', weight_decay=0.0, z_dim=100)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "parser = get_parser()\n", + "args = parser.parse_args(args=['--lr', '1e-5',\n", + " '--n-epochs', '20000', \n", + " '--g-rep-dim', '50',\n", + " '--k-dim', '50',\n", + " '--data-dir', './oligo_lite.h5ad',\n", + " '--model-name', 'oligo_model.cpt',\n", + " '--exp-name', 'CohAE_oligo',\n", + " '--device', 'cuda:2',\n", + " '--gumbsoft_tau', '5',\n", + " '--nb_g_src', 'X',\n", + " '--ld_nb_g_src', \"SU\",\n", + " '--n_raw_gene', '2000',\n", + " '--n_conn_nb', '30',\n", + " '--n_nb_newadata', '30',\n", + " '--aux_weight', '100',\n", + " '--fit_offset_train', 'false',\n", + " '--fit_offset_pred', 'true',\n", + " '--use_offset_pred', 'false',\n", + " '--gnn_layer', 'GAT',\n", + " '--vis-key', 'X_umap',\n", + " '--vis_type_col', 'celltype',\n", + " '--scv_n_jobs', '10',\n", + " '--mask_cluster_list', 'velo_constraint_cluster_list.txt', # in comparison with the retina dataset adding velocity constrints to\n", + " # low-dimentional space for all the cells, \n", + " # we provide an interface for users to specify specific cell clusters\n", + " # which we would like to constrain\n", + " # this argument specifes the path to a file with the first row indicating\n", + " # the cluster column name in adata.obs DataFrame\n", + " # and the rest rows denote the cluster types for constrain.\n", + " # in this example, we constrain the velocity of `NFOLs`\n", + " # to be similar between projected and linear regression estimated low-dimensional space,\n", + " # doing some type of `transfer learning` from scvelo stochastic mode to veloAE\n", + " '--v_rg_wt', '500',\n", + " '--sl1_beta', '8'\n", + " ])\n", + "args " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(args.seed)\n", + "torch.cuda.manual_seed(args.seed)\n", + "np.random.seed(args.seed)\n", + "torch.backends.cudnn.deterministic = True\n", + "\n", + "device = torch.device(args.device if args.device.startswith('cuda') and torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "cluster_edges = [\n", + " (\"COPs\", \"NFOLs\"), \n", + " (\"NFOLs\", \"MFOLs\")]\n", + "EXP_NAME = args.exp_name\n", + "exp_metrics = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def main_AE(args, adata):\n", + " spliced = adata.layers['Ms']\n", + " unspliced = adata.layers['Mu']\n", + " tensor_s = torch.FloatTensor(spliced).to(device)\n", + " tensor_u = torch.FloatTensor(unspliced).to(device)\n", + " tensor_x = torch.FloatTensor(adata.X.toarray()).to(device)\n", + " tensor_v = torch.FloatTensor(adata.layers['stc_velocity']).to(device)\n", + "\n", + " model = init_model(adata, args, device)\n", + "\n", + " inputs = [tensor_s, tensor_u]\n", + " xyids = [0, 1]\n", + " if args.use_x:\n", + " inputs.append(tensor_x)\n", + "\n", + " model = fit_model(args, adata, model, inputs, tensor_v, xyids, device)\n", + " return tensor_s, tensor_u, tensor_x " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Abundance of ['spliced', 'unspliced']: [0.64 0.36]\n", + "AnnData object with n_obs × n_vars = 6253 × 2000\n", + " obs: 'Age', 'Cell_Conc', 'ChipID', 'Clusters', 'Comments', 'Date_Captured', 'DonorID', 'Label', 'NGI_PlateWell', 'Num_Pooled_Animals', 'PCR_Cycles', 'Plug_Date', 'Project', 'SampleID', 'SampleOK', 'Sample_Index', 'Seq_Comment', 'Seq_Lib_Date', 'Seq_Lib_Ok', 'Serial_Number', 'Sex', 'Species', 'Strain', 'Target_Num_Cells', 'Tissue', 'Transcriptome', '_X', '_Y', 'cDNA_Lib_Ok', 'ngperul_cDNA', 'V1', 'V2', 'label', 'celltype', 'initial_size_spliced', 'initial_size_unspliced', 'initial_size', 'n_counts'\n", + " var: 'Accession', 'Chromosome', 'End', 'Start', 'Strand', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'\n", + " uns: 'pca', 'neighbors'\n", + " obsm: 'X_coor', 'X_umap', 'X_pca'\n", + " varm: 'PCs'\n", + " layers: 'ambiguous', 'matrix', 'spliced', 'unspliced', 'Ms', 'Mu'\n", + " obsp: 'distances', 'connectivities'\n" + ] + } + ], + "source": [ + "adata = scanpy.read_h5ad(args.data_dir)\n", + "scv.pp.neighbors(adata, n_neighbors=30, n_pcs=30)\n", + "scv.utils.show_proportions(adata)\n", + "scv.pp.filter_and_normalize(adata, min_shared_counts=30, n_top_genes=args.n_raw_gene)\n", + "scv.pp.moments(adata, n_pcs=30, n_neighbors=30)\n", + "print(adata)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d4e104312ec0496582218eb3722b52f0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/6253 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# Cross-Boundary Transition Score (A->B)\n", + "{('COPs', 'NFOLs'): 0.13601112, ('NFOLs', 'MFOLs'): 0.2676039}\n", + "Total Mean: 0.20180751383304596\n", + "# Cross-Boundary Velocity Coherence (A->B)\n", + "{('COPs', 'NFOLs'): 0.86201894, ('NFOLs', 'MFOLs'): 0.8960272}\n", + "Total Mean: 0.8790230751037598\n", + "# Cross-Boundary Direction Correctness (A->B)\n", + "{('COPs', 'NFOLs'): 0.18456332196112613, ('NFOLs', 'MFOLs'): 0.527163266529319}\n", + "Total Mean: 0.35586329424522256\n", + "# In-cluster Coherence\n", + "{'COPs': 0.89906085, 'MFOLs': 0.9619137, 'NFOLs': 0.9057924, 'OPCs': 0.8825826}\n", + "Total Mean: 0.9123374223709106\n", + "# In-cluster Confidence\n", + "{'COPs': 0.8975995237356137, 'MFOLs': 0.9524482225827059, 'NFOLs': 0.8943025764184453, 'OPCs': 0.8614016425225043}\n", + "Total Mean: 0.9014379913148174\n" + ] + } + ], + "source": [ + "scv.tl.velocity(adata, vkey='stc_velocity', mode=\"stochastic\")\n", + "scv.tl.velocity_graph(adata, vkey='stc_velocity', n_jobs=args.scv_n_jobs)\n", + "scv.tl.velocity_confidence(adata, vkey='stc_velocity')\n", + "scv.pl.velocity_embedding_stream(adata, vkey=\"stc_velocity\", basis=args.vis_key, color=args.vis_type_col,\n", + " dpi=150, \n", + " title='ScVelo Stochastic Mode')\n", + "exp_metrics[\"stc_mode\"] = evaluate(adata, cluster_edges, args.vis_type_col, \"stc_velocity\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loss: (Total) 78.359283, (AE) 77.568756, (LR) 100.00 * 0.000732, (RG) 500.00 * 0.001435: 100%|██████████| 20000/20000 [23:42<00:00, 14.06it/s] \n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD4CAYAAADsKpHdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAa/klEQVR4nO3dfWxd9Z3n8ffHdhJCIQESF6V5qAOEagFN08abZbZDt6tMS4o6DZ2BjqNRk9EgpTAgtZpd7cBWu0UrRRpmlrIbjZoqlAxQtTwMlCF/wCwsjMrMiicHUhIeMjiQNiZuEgiFQEKC7e/+cX/XnGMf28m9174O5/OSru7x9zzc771O7sfnd869RxGBmZlZS7MbMDOzqcGBYGZmgAPBzMwSB4KZmQEOBDMzS9qa3UCt5s6dGx0dHc1uw8zspLJ169Y3I6K9aN5JGwgdHR10d3c3uw0zs5OKpF+NNs9DRmZmBjgQzMwscSCYmRngQDAzs8SBYGZmgAPBzMwSB4KZmQElDIRndx/kB4/s5Fj/YLNbMTObUkoXCM/96m02PN5D/6ADwcwsq3SBYGZmxcYNBEmbJe2XtCNTu0fStnTbLWlbqndIOpKZ96PMOsskbZfUI2mDJKX6jLS9HklPS+po/NMcyReKMzPLO549hNuBldlCRPxxRCyNiKXA/cDPM7N3VedFxNWZ+kZgHbAk3arbvAp4OyLOA24BbqrliRyvSgyZmdlw4wZCRDwBHCyal/7K/yZw11jbkDQPmBURT0blIs53Apen2auAO9L0fcCK6t7DRPIOgplZXr3HEC4B9kXEq5naYknPS/qFpEtSbT7Qm1mmN9Wq8/YAREQ/8A4wp+jBJK2T1C2p+8CBAzU1LLyLYGZWpN5AWE1+76APWBQRnwP+AviZpFlQ+C5c/SN9rHn5YsSmiOiMiM729sKv8zYzsxrVfD0ESW3AHwLLqrWIOAocTdNbJe0CzqeyR7Ags/oCYG+a7gUWAr1pm7MZZYiqkcJHlc3McurZQ/h94JWIGBoKktQuqTVNn0Pl4PFrEdEHHJJ0cTo+sAZ4MK22BVibpq8AHo8JfLf2QWUzs2LHc9rpXcCTwGck9Uq6Ks3qYuTB5C8CL0j6JZUDxFdHRPWv/WuAHwM9wC7g4VS/DZgjqYfKMNP1dTwfMzOr0bhDRhGxepT6nxbU7qdyGmrR8t3ARQX1D4Arx+uj0TxgZGaW508qm5kZ4EAwM7OktIHgk4zMzPJKFwiT8CFoM7OTUukCYYj3EMzMckoXCN4/MDMrVrpAMDOzYqUNhPCYkZlZTukCwceUzcyKlS4QzMysWGkDwZ9DMDPLK10geMTIzKxY6QLBzMyKlTYQPGJkZpZXukDwV1eYmRUrXSBU+RKaZmZ5pQsE7yCYmRUrXSCYmVmx0gaCB4zMzPLGDQRJmyXtl7QjU7tR0huStqXbZZl5N0jqkbRT0qWZ+jJJ29O8DUpHdyXNkHRPqj8tqaPBzzH/fCZy42ZmJ7Hj2UO4HVhZUL8lIpam20MAki4AuoAL0zo/lNSalt8IrAOWpFt1m1cBb0fEecAtwE01PhczM6vDuIEQEU8AB49ze6uAuyPiaES8DvQAyyXNA2ZFxJNROb3nTuDyzDp3pOn7gBWahHNDfZKRmVlePccQrpP0QhpSOjPV5gN7Msv0ptr8ND28nlsnIvqBd4A5RQ8oaZ2kbkndBw4cqK1rn2ZkZlao1kDYCJwLLAX6gJtTvejdNsaoj7XOyGLEpojojIjO9vb2E2p45AN4F8HMLKumQIiIfRExEBGDwK3A8jSrF1iYWXQBsDfVFxTUc+tIagNmc/xDVCfM+wdmZsVqCoR0TKDqG0D1DKQtQFc6c2gxlYPHz0REH3BI0sXp+MAa4MHMOmvT9BXA4+GPEZuZTbq28RaQdBfwJWCupF7g+8CXJC2lMrSzG/g2QES8KOle4CWgH7g2IgbSpq6hcsbSTODhdAO4DfiJpB4qewZdDXhe43PkmJnljBsIEbG6oHzbGMuvB9YX1LuBiwrqHwBXjtdHo/iYsplZsdJ+UtnMzPJKGwgeMTIzyytdIMjnGZmZFSpdIJiZWbHSBoJPbDUzyytdIPgsIzOzYqULhCp/dYWZWV7pAsE7CGZmxUoXCGZmVqy0geCDymZmeaULBB9UNjMrVrpAMDOzYqUNBI8YmZnllS4Q/NUVZmbFShcIZmZWrLSB4IuymZnllS8QPGJkZlaofIGQeAfBzCxv3ECQtFnSfkk7MrW/kfSKpBckPSDpjFTvkHRE0rZ0+1FmnWWStkvqkbRBqnwiQNIMSfek+tOSOhr/NDPPZyI3bmZ2EjuePYTbgZXDao8CF0XE7wD/CtyQmbcrIpam29WZ+kZgHbAk3arbvAp4OyLOA24BbjrhZ2FmZnUbNxAi4gng4LDaIxHRn358Clgw1jYkzQNmRcSTUTmaeydweZq9CrgjTd8HrKjuPZiZ2eRpxDGEPwMezvy8WNLzkn4h6ZJUmw/0ZpbpTbXqvD0AKWTeAeY0oK9Czhozs2Jt9aws6XtAP/DTVOoDFkXEW5KWAf8g6UKKh+6rh3XHmjf88dZRGXZi0aJF9bRuZmbD1LyHIGkt8DXgT9IwEBFxNCLeStNbgV3A+VT2CLLDSguAvWm6F1iYttkGzGbYEFVVRGyKiM6I6Gxvb6+19bStulY3M/vYqSkQJK0E/hL4ekQcztTbJbWm6XOoHDx+LSL6gEOSLk7HB9YAD6bVtgBr0/QVwOMxgZ8a84CRmVmxcYeMJN0FfAmYK6kX+D6Vs4pmAI+mMfmn0hlFXwT+h6R+YAC4OiKqf+1fQ+WMpZlUjjlUjzvcBvxEUg+VPYOuhjyzcfgSmmZmeeMGQkSsLijfNsqy9wP3jzKvG7iooP4BcOV4fTSKjymbmRUr7SeVzcwsr7SB4IPKZmZ5pQsEDxmZmRUrXSCYmVmx0gaCR4zMzPJKFwi+hKaZWbHSBYKZmRUrbSD4EppmZnmlCwSfZWRmVqx0gVDl/QMzs7zSBoKZmeU5EMzMDChxIPiYsplZXukCwZfQNDMrVrpAMDOzYiUOBI8ZmZlllS4QPGBkZlasdIFgZmbFShsIPsvIzCxv3ECQtFnSfkk7MrWzJD0q6dV0f2Zm3g2SeiTtlHRppr5M0vY0b4PS6T6SZki6J9WfltTR4Oc47PlM5NbNzE5ex7OHcDuwcljteuCxiFgCPJZ+RtIFQBdwYVrnh5Ja0zobgXXAknSrbvMq4O2IOA+4Bbip1idzIryDYGaWN24gRMQTwMFh5VXAHWn6DuDyTP3uiDgaEa8DPcBySfOAWRHxZFS+ZvTOYetUt3UfsEIT+GEBXw/BzKxYrccQzo6IPoB0/8lUnw/sySzXm2rz0/Twem6diOgH3gHmFD2opHWSuiV1HzhwoMbWzcysSKMPKhf9+R1j1MdaZ2QxYlNEdEZEZ3t7e40tVrdV1+pmZh87tQbCvjQMRLrfn+q9wMLMcguAvam+oKCeW0dSGzCbkUNUDeODymZmxWoNhC3A2jS9FngwU+9KZw4tpnLw+Jk0rHRI0sXp+MCaYetUt3UF8Hj4cmZmZpOubbwFJN0FfAmYK6kX+D7wV8C9kq4Cfg1cCRARL0q6F3gJ6AeujYiBtKlrqJyxNBN4ON0AbgN+IqmHyp5BV0Oe2TjC5xmZmeWMGwgRsXqUWStGWX49sL6g3g1cVFD/gBQok8EjRmZmxfxJZTMzA0oYCD6obGZWrHSBYGZmxUobCB4yMjPLK2EgeMzIzKxICQPBzMyKlDYQ/DkEM7O80gWCzzIyMytWukAwM7NipQ0En2VkZpZXukDwiJGZWbHSBYKZmRUrXSBUr87pISMzs7zSBUJLGjPyaadmZnmlC4TqaaeDzgMzs5wSBkJ1yMiJYGaWVbpAaEmB4D0EM7O80gVC9bRT7yGYmeWVLhCqewiOAzOzvJoDQdJnJG3L3N6V9F1JN0p6I1O/LLPODZJ6JO2UdGmmvkzS9jRvgzRx3zhUPcto0GNGZmY5NQdCROyMiKURsRRYBhwGHkizb6nOi4iHACRdAHQBFwIrgR9Kak3LbwTWAUvSbWWtfY3LZxmZmRVq1JDRCmBXRPxqjGVWAXdHxNGIeB3oAZZLmgfMiognozKwfydweYP6GuGjISMngplZVqMCoQu4K/PzdZJekLRZ0pmpNh/Yk1mmN9Xmp+nh9REkrZPULan7wIEDNTX60UHlmlY3M/vYqjsQJE0Hvg78fSptBM4FlgJ9wM3VRQtWjzHqI4sRmyKiMyI629vba+q3pcVfXWFmVqQRewhfBZ6LiH0AEbEvIgYiYhC4FVielusFFmbWWwDsTfUFBfUJMXRQ2YlgZpbTiEBYTWa4KB0TqPoGsCNNbwG6JM2QtJjKweNnIqIPOCTp4nR20RrgwQb0NYrqB9McCGZmWW31rCzpVODLwLcz5b+WtJTKsM/u6ryIeFHSvcBLQD9wbUQMpHWuAW4HZgIPp9uE+OjL7czMLKuuQIiIw8CcYbVvjbH8emB9Qb0buKieXo6Xv8vIzKxYCT+pXLl3HpiZ5ZUwEPzldmZmRUoXCFU+qGxmlle6QGjxJTTNzAqVLhA0dAzBiWBmllW6QPDXX5uZFSthIFTufQzBzCyvdIEgf/21mVmhEgaCP5hmZlakfIGQ7p0HZmZ5pQsEXyDHzKxYaQNhcLDJjZiZTTGlCwT5LCMzs0KlDQTHgZlZXukCocVnGZmZFSpdIPhzCGZmxUoXCP5yOzOzYqULhOrnEHxQ2cwsr65AkLRb0nZJ2yR1p9pZkh6V9Gq6PzOz/A2SeiTtlHRppr4sbadH0gZVP048AeQvtzMzK9SIPYT/GBFLI6Iz/Xw98FhELAEeSz8j6QKgC7gQWAn8UFJrWmcjsA5Ykm4rG9BXoRZ//bWZWaGJGDJaBdyRpu8ALs/U746IoxHxOtADLJc0D5gVEU9G5V36zsw6DaehD6Y5EMzMsuoNhAAekbRV0rpUOzsi+gDS/SdTfT6wJ7Nub6rNT9PD6xOixZ9DMDMr1Fbn+l+IiL2SPgk8KumVMZYtOi4QY9RHbqASOusAFi1adKK9pibSHoITwcwsp649hIjYm+73Aw8Ay4F9aRiIdL8/Ld4LLMysvgDYm+oLCupFj7cpIjojorO9vb2mntUytK2a1jcz+7iqORAkfULS6dVp4CvADmALsDYtthZ4ME1vAbokzZC0mMrB42fSsNIhSRens4vWZNZpuKEvt3MgmJnl1DNkdDbwQDpI2wb8LCL+UdKzwL2SrgJ+DVwJEBEvSroXeAnoB66NiIG0rWuA24GZwMPpNiFa5SEjM7MiNQdCRLwGfLag/hawYpR11gPrC+rdwEW19nIiWtNR5f4Bf/+1mVlW6T6p3FYNBO8imJnllC4QWlpEi2DAgWBmllO6QABoa2nhwwEHgplZVjkDoVUM+BqaZmY5pQyE1hZ5D8HMbJhSBsK01hYfQzAzG6aUgdDaIvo9ZGRmllPKQGhrEf0eMjIzyylnILTKQ0ZmZsOUMxBaWvjQgWBmllPSQPBpp2Zmw5UzEFpbONbvPQQzs6xSBsLMaS0c7R8Yf0EzsxIpZyBMb+XwMQeCmVlWOQNhWitHHAhmZjmlDIRTprXywYcOBDOzrFIGwsxprRxxIJiZ5ZQzEKY7EMzMhitnIPgYgpnZCDUHgqSFkv5J0suSXpT0nVS/UdIbkral22WZdW6Q1CNpp6RLM/VlkraneRskqb6nNbZTprVytH+QQX9a2cxsSFsd6/YD/ykinpN0OrBV0qNp3i0R8T+zC0u6AOgCLgQ+BfxfSedHxACwEVgHPAU8BKwEHq6jtzHNnN4KwAf9A5w6vZ6XwMzs46PmPYSI6IuI59L0IeBlYP4Yq6wC7o6IoxHxOtADLJc0D5gVEU9GRAB3ApfX2tfxmDmtEggeNjIz+0hDjiFI6gA+BzydStdJekHSZklnptp8YE9mtd5Um5+mh9eLHmedpG5J3QcOHKi536FA8IFlM7MhdQeCpNOA+4HvRsS7VIZ/zgWWAn3AzdVFC1aPMeojixGbIqIzIjrb29tr7vmU6pCRA8HMbEhdgSBpGpUw+GlE/BwgIvZFxEBEDAK3AsvT4r3AwszqC4C9qb6goD5hqnsI/voKM7OP1HOWkYDbgJcj4geZ+rzMYt8AdqTpLUCXpBmSFgNLgGciog84JOnitM01wIO19nU8Zp1SOZD87pH+iXwYM7OTSj2n2HwB+BawXdK2VPuvwGpJS6kM++wGvg0QES9Kuhd4icoZStemM4wArgFuB2ZSObtows4wAjjrE9MBOHj42EQ+jJnZSaXmQIiIf6F4/P+hMdZZD6wvqHcDF9Xay4mqBsLb7zsQzMyqSvlJ5dkzpyHBWw4EM7MhpQyEttYWZs+c5j0EM7OMUgYCwFmnTvcxBDOzjPIGwiemew/BzCyj1IHw5ntHm92GmdmUUdpA+NQZM9n72w+ofH2SmZmVNhAWnDmT947289vDHza7FTOzKaG0gbDwrFMB2PP24SZ3YmY2NZQ3EM5MgXDwSJM7MTObGkobCIvmVALh9Tffa3InZmZTQ2kD4bQZbXx6zqm81Pdus1sxM5sSShsIABfMm8X2N95pdhtmZlNCqQOhs+Ms9hw8wp6DPrBsZlbqQPgP588F4J9ffbPJnZiZNV+pA+Hc9tP49JxTeeD53vEXNjP7mCt1IEhize928Ozut+nefbDZ7ZiZNVWpAwHgj//tQubNPoX/cv8LvPuBP7VsZuVV+kA4bUYbN1/5WX791mFWb3qKV/cdanZLZmZNUfpAAPj3583l1jWdvPHbI3zlfz3Bn/7dM/z4n1/jO3c/z/O/fnvMdf2NqWb2caGp8m2fklYC/xtoBX4cEX811vKdnZ3R3d3d0B7eeu8om//f6zy0/Te8/ub7Q/X5Z8zkjd9WvuLi7FkzOPj+Mb51cQcPbnuDt94/xh99fgG/t2QO01tbmd7WwvS2FlolWlqgVaK1RSjdV+vTWluYdco0WlQ5lpG7R6gFWiREuhdIlelqXWkdM7PjJWlrRHQWzpsKgSCpFfhX4MtAL/AssDoiXhptnYkIhKy+d47wtQ3/wvlnn868M07hH3f8hsPHBibs8WolgYDB9Guc3toC4qNgSXmh3Dof/ZSLExVOjggdHcdy+fooDzLmtrJ1jbt8UZ+F6xzHdkf2Ut/rZdZo31mxhD/47KdqWnesQGirq6vGWQ70RMRrAJLuBlYBowbCRJs3eyZb/9uXh37+wTdHLjM4GBwbGOTg+8c41j/IsYFBjn44yLGBAQYGYWAwGIwYuq9MV+rHBgZ574N+BiMIICIYHKxMD0bl5wiG5g+mnyOCwRj58/Y33uHfzJuFlJ8HkM38bPzn61FYHy77B8SJbmv4ZvOPM8o6x9nj8fQyyuSIa2KMvq3idUZ5GmYTZvbMaROy3akSCPOBPZmfe4F/N3whSeuAdQCLFi2anM7G0NIiTmlp5VNnzGx2K2ZmdZsqB5WL9rFH/K0VEZsiojMiOtvb2yehLTOz8pgqgdALLMz8vADY26RezMxKaaoEwrPAEkmLJU0HuoAtTe7JzKxUpsQxhIjol3Qd8H+onHa6OSJebHJbZmalMiUCASAiHgIeanYfZmZlNVWGjMzMrMkcCGZmBjgQzMwsmRJfXVELSQeAX9W4+lxgKl4mzX2dGPd14qZqb+7rxNTT16cjovCDXCdtINRDUvdo3+XRTO7rxLivEzdVe3NfJ2ai+vKQkZmZAQ4EMzNLyhoIm5rdwCjc14lxXyduqvbmvk7MhPRVymMIZmY2Uln3EMzMbBgHgpmZASUMBEkrJe2U1CPp+gl+rIWS/knSy5JelPSdVL9R0huStqXbZZl1bki97ZR0aaa+TNL2NG+D6rxOo6TdaXvbJHWn2lmSHpX0aro/czL7kvSZzGuyTdK7kr7brNdL0mZJ+yXtyNQa9hpJmiHpnlR/WlJHHX39jaRXJL0g6QFJZ6R6h6QjmdfuR5PcV8N+dw3u655MT7slbZvM10ujvzc0999X5VKN5bhR+SbVXcA5wHTgl8AFE/h484DPp+nTqVw3+gLgRuA/Fyx/QeppBrA49dqa5j0D/C6Viwk9DHy1zt52A3OH1f4auD5NXw/cNNl9Dftd/Qb4dLNeL+CLwOeBHRPxGgF/DvwoTXcB99TR11eAtjR9U6avjuxyw7YzGX017HfXyL6Gzb8Z+O+T+Xox+ntDU/99lW0PYejazRFxDKheu3lCRERfRDyXpg8BL1O5XOhoVgF3R8TRiHgd6AGWS5oHzIqIJ6Py270TuHwCWl4F3JGm78g8RjP6WgHsioixPo0+oX1FxBPAwYLHbNRrlN3WfcCK49mTKeorIh6JiP7041NULjI1qsnqawxNfb2q0vrfBO4aaxuN7muM94am/vsqWyAUXbt5rDfohkm7a58Dnk6l69Lu/ebMbuFo/c1P08Pr9QjgEUlbVblWNcDZEdEHlX+wwCeb0FdVF/n/pM1+vaoa+RoNrZPezN8B5jSgxz+j8pdi1WJJz0v6haRLMo89WX016nc3Ea/XJcC+iHg1U5vU12vYe0NT/32VLRCO69rNDX9Q6TTgfuC7EfEusBE4F1gK9FHZZR2rv4no+wsR8Xngq8C1kr44xrKT2ReqXDXv68Dfp9JUeL3GU0svDe9T0veAfuCnqdQHLIqIzwF/AfxM0qxJ7KuRv7uJ+L2uJv+Hx6S+XgXvDaMuOspjNLSvsgXCpF+7WdI0Kr/wn0bEzwEiYl9EDETEIHArlaGssfrrJT8EUHffEbE33e8HHkg97Eu7oNVd5P2T3VfyVeC5iNiXemz665XRyNdoaB1JbcBsjn/IZQRJa4GvAX+Shg9IQwxvpemtVMaez5+svhr8u2v069UG/CFwT6bfSXu9it4baPK/r7IFwqReuzmN190GvBwRP8jU52UW+wZQPfthC9CVzg5YDCwBnkm7jockXZy2uQZ4sI6+PiHp9Oo0lQOSO9Ljr02Lrc08xqT0lZH7q63Zr9cwjXyNstu6Ani8+kZ+oiStBP4S+HpEHM7U2yW1pulzUl+vTWJfjfzdNayv5PeBVyJiaMhlsl6v0d4baPa/r/GOOn/cbsBlVI7o7wK+N8GP9XtUdtFeALal22XAT4Dtqb4FmJdZ53upt51kzowBOqn8Z9oF/C3pU+Y19nUOlTMWfgm8WH0dqIwvPga8mu7Pmsy+0vZOBd4CZmdqTXm9qIRSH/Ahlb+2rmrkawScQmVYrIfKmSLn1NFXD5Xx4uq/s+rZJX+Ufse/BJ4D/mCS+2rY766RfaX67cDVw5adlNeL0d8bmvrvy19dYWZmQPmGjMzMbBQOBDMzAxwIZmaWOBDMzAxwIJiZWeJAMDMzwIFgZmbJ/weihdbY3E0sqgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "tensor_s, tensor_u, tensor_x = main_AE(args, adata)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8950947162e449fb9d83a42c7f72991b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/6253 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# Cross-Boundary Transition Score (A->B)\n", + "{('COPs', 'NFOLs'): 0.2392812, ('NFOLs', 'MFOLs'): 0.104577646}\n", + "Total Mean: 0.17192941904067993\n", + "# Cross-Boundary Velocity Coherence (A->B)\n", + "{('COPs', 'NFOLs'): 0.9953603, ('NFOLs', 'MFOLs'): 0.996584}\n", + "Total Mean: 0.9959721565246582\n", + "# Cross-Boundary Direction Correctness (A->B)\n", + "{('COPs', 'NFOLs'): 0.8487240633966037, ('NFOLs', 'MFOLs'): 0.5442262991169514}\n", + "Total Mean: 0.6964751812567775\n", + "# In-cluster Coherence\n", + "{'COPs': 0.9963123, 'MFOLs': 0.99721044, 'NFOLs': 0.99829817, 'OPCs': 0.99997526}\n", + "Total Mean: 0.9979490041732788\n", + "# In-cluster Confidence\n", + "{'COPs': 0.996084103778113, 'MFOLs': 0.9970275277809268, 'NFOLs': 0.9978537890824896, 'OPCs': 0.9856352421545214}\n", + "Total Mean: 0.9941501656990126\n" + ] + } + ], + "source": [ + "def exp(adata, exp_metrics):\n", + " model = init_model(adata, args, device)\n", + " model.load_state_dict(torch.load(args.model_name))\n", + " model = model.to(device)\n", + " model.eval()\n", + " with torch.no_grad():\n", + " x = model.encoder(tensor_x)\n", + " s = model.encoder(tensor_s)\n", + " u = model.encoder(tensor_u)\n", + " \n", + " v = estimate_ld_velocity(s, u, device=device, perc=[5, 95], \n", + " norm=args.use_norm, fit_offset=args.fit_offset_pred, \n", + " use_offset=args.use_offset_pred).cpu().numpy()\n", + " x = x.cpu().numpy()\n", + " s = s.cpu().numpy()\n", + " u = u.cpu().numpy()\n", + " \n", + " adata = new_adata(adata, x, s, u, v, g_basis=args.ld_nb_g_src, n_nb_newadata=args.n_nb_newadata)\n", + " scv.tl.velocity_graph(adata, vkey='new_velocity', n_jobs=args.scv_n_jobs)\n", + " scv.pl.velocity_embedding_stream(adata, vkey=\"new_velocity\", basis=args.vis_key, color=args.vis_type_col,\n", + " title=\"Project Original Velocity into Low-Dim Space\",\n", + " dpi=150,\n", + " save='oligodendrocyte_pojection.png') \n", + " scv.tl.velocity_confidence(adata, vkey='new_velocity')\n", + " exp_metrics['Cohort AutoEncoder'] = evaluate(adata, cluster_edges, args.vis_type_col, \"new_velocity\")\n", + " \n", + "exp(adata, exp_metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Abundance of ['spliced', 'unspliced']: [0.64 0.36]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b365cfd20d234ea0bb23482306bf2916", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/757 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "# Cross-Boundary Transition Score (A->B)\n", + "{('COPs', 'NFOLs'): 0.39235646, ('NFOLs', 'MFOLs'): 0.43137965}\n", + "Total Mean: 0.41186803579330444\n", + "# Cross-Boundary Velocity Coherence (A->B)\n", + "{('COPs', 'NFOLs'): 0.8339525880108608, ('NFOLs', 'MFOLs'): 0.7648149955134319}\n", + "Total Mean: 0.7993837917621464\n", + "# Cross-Boundary Direction Correctness (A->B)\n", + "{('COPs', 'NFOLs'): -0.045721793801200644, ('NFOLs', 'MFOLs'): 0.09865751437961526}\n", + "Total Mean: 0.026467860289207307\n", + "# In-cluster Coherence\n", + "{'COPs': 0.9741825131895359, 'MFOLs': 0.8120565060783369, 'NFOLs': 0.7916376122571769, 'OPCs': 0.9990290240941994}\n", + "Total Mean: 0.8942264139048123\n", + "# In-cluster Confidence\n", + "{'COPs': 0.9734623313499147, 'MFOLs': 0.8024698073010618, 'NFOLs': 0.7887305956486915, 'OPCs': 0.9938755845841556}\n", + "Total Mean: 0.8896345797209559\n" + ] + } + ], + "source": [ + "adata = scanpy.read_h5ad(args.data_dir)\n", + "scv.utils.show_proportions(adata)\n", + "scv.pp.filter_and_normalize(adata, min_shared_counts=30, n_top_genes=args.n_raw_gene)\n", + "scv.pp.neighbors(adata, n_pcs=30, n_neighbors=30)\n", + "scv.pp.moments(adata, n_pcs=30, n_neighbors=30)\n", + "\n", + "scv.tl.recover_dynamics(adata, n_jobs=args.scv_n_jobs)\n", + "scv.tl.velocity(adata, vkey='dyn_velocity', mode=\"dynamical\")\n", + "\n", + "scv.tl.velocity_graph(adata, vkey='dyn_velocity', n_jobs=args.scv_n_jobs)\n", + "scv.tl.velocity_confidence(adata, vkey='dyn_velocity')\n", + "scv.pl.velocity_embedding_stream(adata, \n", + " vkey=\"dyn_velocity\", \n", + " basis=args.vis_key, \n", + " color=[args.vis_type_col],\n", + " dpi=150, \n", + " title='ScVelo Dynamical Mode')\n", + "exp_metrics[\"dyn_mode\"] = evaluate(adata[:, adata.var.dyn_velocity_genes], cluster_edges, args.vis_type_col, \"dyn_velocity\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"{}.pkl\".format(EXP_NAME), 'wb') as out_file:\n", + " pickle.dump(exp_metrics, out_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "a8471068523155614cf4eb871cd7d17435a8456aa34f037779c06d1e28b048ac" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/oligodendrocyte/oligo_model.cpt b/notebooks/oligodendrocyte/oligo_model.cpt new file mode 100644 index 0000000..545349e Binary files /dev/null and b/notebooks/oligodendrocyte/oligo_model.cpt differ diff --git a/notebooks/oligodendrocyte/velo_constraint_cluster_list.txt b/notebooks/oligodendrocyte/velo_constraint_cluster_list.txt new file mode 100644 index 0000000..7816e82 --- /dev/null +++ b/notebooks/oligodendrocyte/velo_constraint_cluster_list.txt @@ -0,0 +1,2 @@ +celltype +NFOLs \ No newline at end of file