diff --git a/example/language_model/language-model.py b/example/language_model/language-model.py index ea53fb020..9ce46e5fd 100644 --- a/example/language_model/language-model.py +++ b/example/language_model/language-model.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import matplotlib.pyplot as plt -# How much tokens to keep as context when making the prediction for the next one +# Number of tokens to keep as context when making the prediction for the next one CONTEXT_SIZE = 3 # Size of the vector to represent a single token EMBEDDING_SIZE = 10 diff --git a/gigatorch/nn.py b/gigatorch/nn.py index e83865a27..44c6597f8 100644 --- a/gigatorch/nn.py +++ b/gigatorch/nn.py @@ -58,7 +58,6 @@ def __call__(self, x): def calc_loss(self, ys, y_pred): # Convertin y_pred to probabilities prob = [self.prob_fn(i, y_pred) for i in y_pred] - print("Prob", prob) loss = sum(self.loss_fn(ys, y_pred), Tensor(0)) loss.backprop() return loss.data diff --git a/requirements.txt b/requirements.txt index 5293a5e0e..6f7c10d10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,21 +1,55 @@ +appnope==0.1.4 +asttokens==2.4.1 black==24.2.0 click==8.1.7 +comm==0.2.2 +contourpy==1.2.0 +cycler==0.12.1 +debugpy==1.8.1 +decorator==5.1.1 +executing==2.0.1 filelock==3.13.1 +fonttools==4.49.0 fsspec==2024.2.0 iniconfig==2.0.0 +ipykernel==6.29.3 +ipython==8.22.2 +jedi==0.19.1 Jinja2==3.1.3 +jupyter_client==8.6.1 +jupyter_core==5.7.2 +kiwisolver==1.4.5 MarkupSafe==2.1.5 +matplotlib==3.8.3 +matplotlib-inline==0.1.6 mpmath==1.3.0 mypy-extensions==1.0.0 +nest-asyncio==1.6.0 networkx==3.2.1 numpy==1.26.4 packaging==23.2 +parso==0.8.3 pathspec==0.12.1 +pexpect==4.9.0 pillow==10.2.0 platformdirs==4.2.0 pluggy==1.4.0 +prompt-toolkit==3.0.43 +psutil==5.9.8 +ptyprocess==0.7.0 +pure-eval==0.2.2 +Pygments==2.17.2 +pyparsing==3.1.2 pytest==8.0.2 +python-dateutil==2.9.0.post0 +pyzmq==25.1.2 setuptools-black==0.1.5 +six==1.16.0 +stack-data==0.6.3 sympy==1.12 torch==2.2.1 +torchvision==0.17.1 +tornado==6.4 +traitlets==5.14.2 typing_extensions==4.10.0 +wcwidth==0.2.13