Cannot save any torch tensor or nn_module #1112

Open
rdinnager opened this issue Oct 6, 2023 · 4 comments

@rdinnager
Collaborator

Ever since I updated torch to the latest version, I have not been able to save any torch object using the default serialization method. I have instead had to fall back on options(torch.serialization_version = 2) to be able to save any tensors or model objects. Here is a reprex:

```r
library(torch)

x <- torch_rand(100, 1000)
torch_save(x, "test.pth")
#> Error in `FUN()`:
#> ! `metadata` must be a named list of scalar characters.
#> Backtrace:
#>      ▆
#>   1. ├─torch::torch_save(x, "test.pth")
#>   2. └─torch:::torch_save.torch_tensor(x, "test.pth")
#>   3.   └─torch:::torch_save_to_file(...)
#>   4.     └─safetensors::safe_save_file(state_dict, path = con, metadata = metadata)
#>   5.       └─safetensors:::write_safe(tensors, metadata, con)
#>   6.         └─safetensors:::make_meta(tensors, metadata)
#>   7.           └─safetensors:::validate_metadata(metadata)
#>   8.             └─base::lapply(...)
#>   9.               └─safetensors (local) FUN(X[[i]], ...)
#>  10.                 └─cli::cli_abort("{.arg metadata} must be a named list of scalar characters.")
#>  11.                   └─rlang::abort(...)

## trying to use safetensors directly gives a different error:
safetensors::safe_save_file(x, "test.pth")
#> Error in for (tensor in tensors) {: invalid for() loop sequence

## This works (but seems very slow for some reason)
options(torch.serialization_version = 2)
torch_save(x, "test.pth")
```

Created on 2023-10-06 with reprex v2.0.2

Any ideas what is going wrong here?
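
A side note on the second error in the reprex: it is probably unrelated to the version problem. As far as I can tell, `safetensors::safe_save_file()` expects a named list of tensors rather than a bare tensor, which matches the `state_dict` argument visible in the backtrace. The sketch below is a minimal illustration under that assumption; the list name `x` and the temporary file path are arbitrary choices, not anything torch itself uses.

```r
library(torch)
library(safetensors)

x <- torch_rand(100, 1000)

# Assumption: safe_save_file() takes a named list of tensors,
# so the bare tensor is wrapped in a list with an arbitrary name.
path <- tempfile(fileext = ".safetensors")
safe_save_file(list(x = x), path)

# Read the file back as torch tensors; the result is a named list.
loaded <- safe_load_file(path, framework = "torch")
```

If this behaves as assumed, `loaded$x` should hold the same values as `x`. It does not address the `metadata` error raised through `torch_save()`.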

@dfalbel
Member

dfalbel commented Oct 6, 2023

Hi @rdinnager,

Thanks for reporting.

That's weird. I'd assume this is a mismatch between the torch and safetensors versions, as I think I saw a similar issue at some point.

Can you try updating your safetensors package? I just tried the latest commit of torch with safetensors (both the CRAN release and the latest commit), and they seem to work correctly.

@rdinnager
Collaborator Author

I am using the latest version of both torch and safetensors from CRAN:

```r
packageVersion("safetensors")
#> [1] '0.1.2'
packageVersion("torch")
#> [1] '0.11.0'
```

I'm thinking now that I actually need the development version of torch to work properly, after looking through the recent commit history. I will try that!

@dfalbel
Member

dfalbel commented Oct 10, 2023

Ohhh, I think that might be the case. You are right: you might need to downgrade safetensors or use the dev version of torch. I'm going to make a new torch release soon.

@rdinnager
Collaborator Author

Yes, I decided to wait for the new release and just use options(torch.serialization_version = 2) in the meantime. I find the new precompiled CUDA binary installation method so convenient that I don't want to bother installing from source at the moment, which would require me to install a compatible CUDA version locally (and I'm not up for that right now ;) ).
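
For anyone who wants the workaround without changing the option for the whole session, here is a minimal sketch in base R. The helper name `with_serialization_v2` is hypothetical and not part of torch; only the option name `torch.serialization_version` and the value `2` come from this thread.

```r
# Hypothetical helper (not part of torch): evaluate an expression with
# torch.serialization_version temporarily set to 2, then restore the
# previous value, even if the expression throws an error.
with_serialization_v2 <- function(expr) {
  old <- options(torch.serialization_version = 2)
  on.exit(options(old), add = TRUE)
  # `expr` is a lazy promise, so it is evaluated here,
  # while the option is still set.
  force(expr)
}
```

With torch loaded, this would be used as `with_serialization_v2(torch_save(x, "test.pth"))`. The option is scoped to that one call, so it is easy to drop once a torch release that works with the installed safetensors is available.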
