Skip to content

Commit

Permalink
Fix RNN for next word prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
orifmilod committed May 13, 2024
1 parent 49bffbd commit 1f9f72c
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 97 deletions.
67 changes: 67 additions & 0 deletions example/language_model/pytorch/cbow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,73 @@
"print(f'Context: {context}\\n')\n",
"print(f'Prediction: {index_to_word[torch.argmax(a[0]).item()]}')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cpu\n",
"mps:0\n",
"Total time with cpu (10000): 0.0005950000000005673\n",
"Total time with cpu (100000000): 0.0002990000000000492\n",
"available\n",
"Total time with gpu (10000): 0.0002899999999996794\n",
"Total time with gpu (100000000): 0.0003420000000007306\n"
]
}
],
"source": [
"from time import process_time\n",
"import torch\n",
"\n",
"print(torch.get_default_device())\n",
"torch.set_default_device('mps')\n",
"print(torch.get_default_device())\n",
"def testgpu():\n",
" if torch.backends.mps.is_available():\n",
" print(\"available\")\n",
" mps_device = torch.device(\"mps\")\n",
" t0 = process_time()\n",
" x = torch.ones(n1, device=mps_device)\n",
" y = x + torch.rand(n1, device=mps_device)\n",
" t1 = process_time()\n",
" print(f\"Total time with gpu ({n1}): {t1-t0}\")\n",
" t0 = process_time()\n",
" x = torch.ones(n2, device=mps_device)\n",
" y = x + torch.rand(n2, device=mps_device)\n",
" t1 = process_time()\n",
" print(f\"Total time with gpu ({n2}): {t1-t0}\")\n",
"\n",
"def testcpu():\n",
" t0 = process_time()\n",
" x = torch.ones(n1)\n",
" y = x + torch.rand(n1)\n",
" t1 = process_time()\n",
" print(f\"Total time with cpu ({n1}): {t1-t0}\")\n",
" t0 = process_time()\n",
" x = torch.ones(n2)\n",
" y = x + torch.rand(n2)\n",
" t1 = process_time()\n",
" print(f\"Total time with cpu ({n2}): {t1-t0}\")\n",
"\n",
"if __name__ == '__main__':\n",
" n1 = 10000\n",
" n2 = 100000000\n",
" testcpu()\n",
" testgpu()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading

0 comments on commit 1f9f72c

Please sign in to comment.