MisconfigurationException: configure_optimizers must include a monitor when a ReduceLROnPlateau scheduler is used #79

Open
adewdev opened this issue Jul 11, 2023 · 5 comments


@adewdev

adewdev commented Jul 11, 2023

(spacetimeformer) adewdev@bootedeagle:~/Devtools/PyWS0/spacetimeformer-main-original/spacetimeformer$ python train.py spacetimeformer solar_energy --context_points 168 --target_points 24 --d_model 100 --d_ff 400 --enc_layers 5 --dec_layers 5 --l2_coeff 1e-3 --dropout_ff .2 --dropout_emb .1 --d_qk 20 --d_v 20 --n_heads 6 --run_name spatiotemporal_al_solar --batch_size 32 --class_loss_imp 0 --initial_downsample_convs 1 --decay_factor .8 --warmup_steps 1000
Using default wandb log dir path of ./data/STF_LOG_DIR. This can be adjusted with the environment variable STF_LOG_DIR
Forecaster
L2: 0.001
Linear Window: 0
Linear Shared Weights: False
RevIN: False
Decomposition: False
GlobalSelfAttn: AttentionLayer(
(inner_attention): PerformerAttention(
(kernel_fn): ReLU()
)
(query_projection): Linear(in_features=100, out_features=120, bias=True)
(key_projection): Linear(in_features=100, out_features=120, bias=True)
(value_projection): Linear(in_features=100, out_features=120, bias=True)
(out_projection): Linear(in_features=120, out_features=100, bias=True)
(dropout_qkv): Dropout(p=0.0, inplace=False)
)
GlobalCrossAttn: AttentionLayer(
(inner_attention): PerformerAttention(
(kernel_fn): ReLU()
)
(query_projection): Linear(in_features=100, out_features=120, bias=True)
(key_projection): Linear(in_features=100, out_features=120, bias=True)
(value_projection): Linear(in_features=100, out_features=120, bias=True)
(out_projection): Linear(in_features=120, out_features=100, bias=True)
(dropout_qkv): Dropout(p=0.0, inplace=False)
)
LocalSelfAttn: AttentionLayer(
(inner_attention): PerformerAttention(
(kernel_fn): ReLU()
)
(query_projection): Linear(in_features=100, out_features=120, bias=True)
(key_projection): Linear(in_features=100, out_features=120, bias=True)
(value_projection): Linear(in_features=100, out_features=120, bias=True)
(out_projection): Linear(in_features=120, out_features=100, bias=True)
(dropout_qkv): Dropout(p=0.0, inplace=False)
)
LocalCrossAttn: AttentionLayer(
(inner_attention): PerformerAttention(
(kernel_fn): ReLU()
)
(query_projection): Linear(in_features=100, out_features=120, bias=True)
(key_projection): Linear(in_features=100, out_features=120, bias=True)
(value_projection): Linear(in_features=100, out_features=120, bias=True)
(out_projection): Linear(in_features=120, out_features=100, bias=True)
(dropout_qkv): Dropout(p=0.0, inplace=False)
)
Using Embedding: spatio-temporal
Time Emb Dim: 6
Space Embedding: True
Time Embedding: True
Val Embedding: True
Given Embedding: True
Null Value: None
Pad Value: None
Reconstruction Dropout: Timesteps 0.05, Standard 0.1, Seq (max len = 5) 0.2, Skip All Drop 1.0
*** Spacetimeformer (v1.5) Summary: ***
Model Dim: 100
FF Dim: 400
Enc Layers: 5
Dec Layers: 5
Embed Dropout: 0.1
FF Dropout: 0.2
Attn Out Dropout: 0.0
Attn Matrix Dropout: 0.0
QKV Dropout: 0.0
L2 Coeff: 0.001
Warmup Steps: 1000
Normalization Scheme: batch
Attention Time Windows: 1
Shifted Time Windows: False
Position Emb Type: abs
Recon Loss Imp: 0.0


/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:287: LightningDeprecationWarning: Passing Trainer(accelerator='dp') has been deprecated in v1.5 and will be removed in v1.7. Use Trainer(strategy='dp') instead.
rank_zero_deprecation(
/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:597: UserWarning: 'dp' is not supported on CPUs, hence setting strategy='ddp'.
rank_zero_warn(f"{strategy_flag!r} is not supported on CPUs, hence setting strategy='ddp'.")
/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py:91: PossibleUserWarning: max_epochs was not set. Setting it to 1000 epochs. To train without an epoch limit, set max_epochs=-1.
rank_zero_warn(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1823: PossibleUserWarning: GPU available but not used. Set accelerator and devices using Trainer(accelerator='gpu', devices=1).
rank_zero_warn(
Trainer(limit_val_batches=1.0) was configured so 100% of the batches will be used..
Trainer(val_check_interval=1.0) was configured so validation will run at the end of the training epoch..
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1

distributed_backend=gloo
All distributed processes registered. Starting with 1 processes

Traceback (most recent call last):
File "train.py", line 869, in
main(args)
File "train.py", line 849, in main
trainer.fit(forecaster, datamodule=data_module)
File "/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 771, in fit
self._call_and_handle_interrupt(
File "/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 722, in _call_and_handle_interrupt
return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
File "/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
return function(*args, **kwargs)
File "/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 812, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1218, in _run
self.strategy.setup(self)
File "/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 155, in setup
super().setup(trainer)
File "/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 139, in setup
self.setup_optimizers(trainer)
File "/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 128, in setup_optimizers
self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
File "/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 190, in _init_optimizers_and_lr_schedulers
_configure_schedulers_automatic_opt(lr_schedulers, monitor)
File "/home/adewdev/Devtools/anaconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 305, in _configure_schedulers_automatic_opt
raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: configure_optimizers must include a monitor when a ReduceLROnPlateau scheduler is used. For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}
(spacetimeformer) adewdev@bootedeagle:~/Devtools/PyWS0/spacetimeformer-main-original/spacetimeformer$ conda list

packages in environment at /home/adewdev/Devtools/anaconda3/envs/spacetimeformer:

Name Version Build Channel

_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
absl-py 1.4.0 pypi_0 pypi
aiohttp 3.8.4 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
antlr4-python3-runtime 4.9.3 pypi_0 pypi
appdirs 1.4.4 pypi_0 pypi
async-timeout 4.0.2 pypi_0 pypi
attrs 23.1.0 pypi_0 pypi
axial-positional-embedding 0.2.1 pypi_0 pypi
blas 1.0 mkl
brotlipy 0.7.0 py38h27cfd23_1003
bzip2 1.0.8 h7b6447c_0
ca-certificates 2023.05.30 h06a4308_0
cachetools 5.3.1 pypi_0 pypi
certifi 2023.5.7 py38h06a4308_0
cffi 1.15.0 py38h7f8727e_0
cftime 1.6.2 pypi_0 pypi
chardet 5.1.0 pypi_0 pypi
charset-normalizer 3.2.0 pypi_0 pypi
click 8.1.4 pypi_0 pypi
cmake 3.26.4 pypi_0 pypi
cmdstanpy 0.9.68 pypi_0 pypi
contourpy 1.1.0 pypi_0 pypi
convertdate 2.4.0 pypi_0 pypi
cryptography 39.0.1 py38h9ce1e76_0
cudatoolkit 11.3.1 h2bc3f7f_2
cycler 0.11.0 pypi_0 pypi
cython 0.29.36 pypi_0 pypi
docker-pycreds 0.4.0 pypi_0 pypi
einops 0.6.1 pypi_0 pypi
ffmpeg 4.3 hf484d3e_0 pytorch
filelock 3.12.2 pypi_0 pypi
fonttools 4.40.0 pypi_0 pypi
freetype 2.12.1 h4a9f257_0
frozenlist 1.3.3 pypi_0 pypi
fsspec 2023.6.0 pypi_0 pypi
giflib 5.2.1 h5eee18b_3
gitdb 4.0.10 pypi_0 pypi
gitpython 3.1.32 pypi_0 pypi
gmp 6.2.1 h295c915_3
gnutls 3.6.15 he1e5248_0
google-auth 2.21.0 pypi_0 pypi
google-auth-oauthlib 1.0.0 pypi_0 pypi
grpcio 1.56.0 pypi_0 pypi
idna 3.4 py38h06a4308_0
importlib-metadata 6.8.0 pypi_0 pypi
importlib-resources 6.0.0 pypi_0 pypi
intel-openmp 2021.4.0 h06a4308_3561
jinja2 3.1.2 pypi_0 pypi
joblib 1.3.1 pypi_0 pypi
jpeg 9e h5eee18b_1
kiwisolver 1.4.4 pypi_0 pypi
lame 3.100 h7b6447c_0
lcms2 2.12 h3be6417_0
lerc 3.0 h295c915_0
libdeflate 1.17 h5eee18b_0
libedit 3.1.20221030 h5eee18b_0
libffi 3.2.1 hf484d3e_1007
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libiconv 1.16 h7f8727e_2
libidn2 2.3.4 h5eee18b_0
libpng 1.6.39 h5eee18b_0
libstdcxx-ng 11.2.0 h1234567_1
libtasn1 4.19.0 h5eee18b_0
libtiff 4.5.0 h6a678d5_2
libunistring 0.9.10 h27cfd23_0
libuv 1.44.2 h5eee18b_0
libwebp 1.2.4 h11a3e52_1
libwebp-base 1.2.4 h5eee18b_1
lit 16.0.6 pypi_0 pypi
local-attention 1.8.6 pypi_0 pypi
lz4-c 1.9.4 h6a678d5_0
markdown 3.4.3 pypi_0 pypi
markupsafe 2.1.3 pypi_0 pypi
matplotlib 3.7.2 pypi_0 pypi
mkl 2021.4.0 h06a4308_640
mkl-service 2.4.0 py38h7f8727e_0
mkl_fft 1.3.1 py38hd3c417c_0
mkl_random 1.2.2 py38h51133e4_0
mpmath 1.3.0 pypi_0 pypi
multidict 6.0.4 pypi_0 pypi
ncurses 6.4 h6a678d5_0
netcdf4 1.6.4 pypi_0 pypi
nettle 3.7.3 hbbd107a_1
networkx 3.1 pypi_0 pypi
numpy 1.24.4 pypi_0 pypi
numpy-base 1.24.3 py38h31eccc5_0
nvidia-cublas-cu11 11.10.3.66 pypi_0 pypi
nvidia-cuda-cupti-cu11 11.7.101 pypi_0 pypi
nvidia-cuda-nvrtc-cu11 11.7.99 pypi_0 pypi
nvidia-cuda-runtime-cu11 11.7.99 pypi_0 pypi
nvidia-cudnn-cu11 8.5.0.96 pypi_0 pypi
nvidia-cufft-cu11 10.9.0.58 pypi_0 pypi
nvidia-curand-cu11 10.2.10.91 pypi_0 pypi
nvidia-cusolver-cu11 11.4.0.1 pypi_0 pypi
nvidia-cusparse-cu11 11.7.4.91 pypi_0 pypi
nvidia-nccl-cu11 2.14.3 pypi_0 pypi
nvidia-nvtx-cu11 11.7.91 pypi_0 pypi
nystrom-attention 0.0.11 pypi_0 pypi
oauthlib 3.2.2 pypi_0 pypi
omegaconf 2.3.0 pypi_0 pypi
opencv-python 4.8.0.74 pypi_0 pypi
openh264 2.1.1 h4ff587b_0
openssl 1.1.1u h7f8727e_0
opt-einsum 3.3.0 pypi_0 pypi
packaging 23.1 pypi_0 pypi
pandas 2.0.3 pypi_0 pypi
pathtools 0.1.2 pypi_0 pypi
performer-pytorch 1.1.4 pypi_0 pypi
pillow 10.0.0 pypi_0 pypi
pip 23.1.2 py38h06a4308_0
protobuf 4.23.4 pypi_0 pypi
psutil 5.9.5 pypi_0 pypi
pyasn1 0.5.0 pypi_0 pypi
pyasn1-modules 0.3.0 pypi_0 pypi
pycparser 2.21 pyhd3eb1b0_0
pydeprecate 0.3.2 pypi_0 pypi
pymeeus 0.5.12 pypi_0 pypi
pyopenssl 23.0.0 py38h06a4308_0
pyparsing 3.0.9 pypi_0 pypi
pysocks 1.7.1 py38h06a4308_0
pystan 2.19.1.1 pypi_0 pypi
python 3.8.0 h0371630_2
python-dateutil 2.8.2 pypi_0 pypi
pytorch 1.11.0 py3.8_cuda11.3_cudnn8.2.0_0 pytorch
pytorch-lightning 1.6.0 pypi_0 pypi
pytorch-mutex 1.0 cuda pytorch
pytz 2023.3 pypi_0 pypi
pyyaml 6.0 pypi_0 pypi
readline 7.0 h7b6447c_5
requests 2.31.0 pypi_0 pypi
requests-oauthlib 1.3.1 pypi_0 pypi
rsa 4.9 pypi_0 pypi
scikit-learn 1.3.0 pypi_0 pypi
scipy 1.10.1 pypi_0 pypi
seaborn 0.12.2 pypi_0 pypi
sentry-sdk 1.28.0 pypi_0 pypi
setproctitle 1.3.2 pypi_0 pypi
setuptools 67.8.0 py38h06a4308_0
six 1.16.0 pyhd3eb1b0_1
smmap 5.0.0 pypi_0 pypi
spacetimeformer 1.5.0 dev_0
sqlite 3.33.0 h62c20be_0
sympy 1.12 pypi_0 pypi
tensorboard 2.13.0 pypi_0 pypi
tensorboard-data-server 0.7.1 pypi_0 pypi
threadpoolctl 3.1.0 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
torch 2.0.1 pypi_0 pypi
torchaudio 0.11.0 py38_cu113 pytorch
torchmetrics 0.5.1 pypi_0 pypi
torchvision 0.12.0 py38_cu113 pytorch
tqdm 4.65.0 pypi_0 pypi
triton 2.0.0 pypi_0 pypi
typing-extensions 4.7.1 pypi_0 pypi
typing_extensions 4.6.3 py38h06a4308_0
tzdata 2023.3 pypi_0 pypi
ujson 5.8.0 pypi_0 pypi
urllib3 1.26.16 py38h06a4308_0
wandb 0.15.5 pypi_0 pypi
werkzeug 2.3.6 pypi_0 pypi
wheel 0.38.4 py38h06a4308_0
xz 5.4.2 h5eee18b_0
yarl 1.9.2 pypi_0 pypi
zipp 3.16.0 pypi_0 pypi
zlib 1.2.13 h5eee18b_0
zstd 1.5.5 hc292b87_0

I followed all the instructions but am still getting the error above. Not sure if I missed anything.

@SEU-ccq

SEU-ccq commented Jul 25, 2023

I'm having the same problem. Have you solved it yet?

@HiFei4869

HiFei4869 commented Jul 28, 2023

I have the same problem. I tried adding the following code to train.py, but it doesn't work.

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.2,
                                                           patience=2,
                                                           min_lr=1e-6,
                                                           verbose=True)
    return {
        "optimizer": optimizer,
        "lr_scheduler": scheduler,  # changed scheduler to lr_scheduler
        "monitor": "metric_to_track"
    }

Sad to find that many people have the same unsolved problem.
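
A guess as to why the train.py change has no effect: Lightning only picks up configure_optimizers when it is a method of the LightningModule itself, and "monitor" has to name a metric that the module actually logs. Below is a minimal, self-contained sketch of the format the error message asks for; TinyForecaster, the toy linear model, and the "val/loss" name are just placeholders for illustration, not spacetimeformer code.

import torch
import pytorch_lightning as pl

class TinyForecaster(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # log the metric that the scheduler will monitor
        self.log("val/loss", torch.nn.functional.mse_loss(self.net(x), y))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.2, patience=2, min_lr=1e-6
        )
        # "monitor" names the metric logged above; leaving it out is what
        # triggers the MisconfigurationException reported in this issue
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val/loss"},
        }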

@HiFei4869

HiFei4869 commented Jul 29, 2023

@adewdev @seuccq
I tried modifying spacetimeformer_model.py; it's under spacetimeformer/spacetimeformer_model.
In the function configure_optimizers, I added a monitor to the return value.
It looks like this:

def configure_optimizers(self):
    optimizer = torch.optim.AdamW(
        self.parameters(),
        lr=self.base_lr,
        weight_decay=self.l2_coeff,
        )
    scheduler = stf.lr_scheduler.WarmupReduceLROnPlateau(
        optimizer,
        init_lr=self.init_lr,
        peak_lr=self.base_lr,
        warmup_steps=self.warmup_steps,
        patience=3,
        factor=self.decay_factor,
    )
    monitor = 'val/loss'
   
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": monitor
        }
    }

It seems to work well; at least it no longer gives me the same error message.

@victorventuri

> (quoting @HiFei4869's configure_optimizers fix from the comment above)

I had this issue too, and your solution solved it for me. Thank you!

@pdy265

pdy265 commented Mar 25, 2024

> (quoting @HiFei4869's configure_optimizers fix from the comment above)

It also seems to have worked for me, thank you very much. I would also like to ask how to obtain the data when using the spacetimeformer model for training.
