Using an RNN with packed sequences where `enforce_sorted=TRUE` gives an error.
Let's define a 3-d tensor with dimensions `(batch_size, max_len, embedding_size)` that represents two embedded sequences of lengths 4 and 2, respectively.
```r
# padded input tensor
batch_size <- 2
input_size <- 3
seq_len <- c(4, 2)  # sequence lengths
padded <- torch_randn(batch_size, max(seq_len), input_size)
padded[2, 3:4, ] <- 0  # padding
padded
```
```
torch_tensor
(1,.,.) =
 -1.0758 -0.5305  1.6832
 -0.1549  2.0737  0.4338
  1.4333  0.5613 -0.5021
  1.2121  0.1815  0.2522

(2,.,.) =
 -1.3125  0.4738  0.4393
  0.6843 -1.1598  0.2858
  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000
[ CPUFloatType{2,4,3} ]
```
The sequence lengths are in decreasing order, and the second sequence is padded with 0.
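For context, this decreasing order is exactly what `enforce_sorted=TRUE` promises. The sketch below (Python, against upstream PyTorch rather than the R package, and not part of the original report) shows that upstream accepts a batch sorted by decreasing length and rejects an unsorted one:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Two padded sequences, max_len = 4, feature size 3 (same shape as above)
padded = torch.randn(2, 4, 3)

# Lengths in decreasing order: accepted when enforce_sorted=True
packed_ok = pack_padded_sequence(padded, torch.tensor([4, 2]),
                                 batch_first=True, enforce_sorted=True)

# Lengths in increasing order: rejected with a RuntimeError
try:
    pack_padded_sequence(padded, torch.tensor([2, 4]),
                         batch_first=True, enforce_sorted=True)
    sorted_required = False
except RuntimeError:
    sorted_required = True
```

So the batch in this report already satisfies the sortedness requirement; the error below is not caused by the input.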
When the padded sequences are packed with the option enforce_sorted=TRUE, the RNN module gives an error:
```r
# define rnn module
hidden_size <- 3
rnn <- nn_rnn(input_size, hidden_size, batch_first = TRUE)

# pack padded input
packed <- nn_utils_rnn_pack_padded_sequence(padded, torch_tensor(seq_len),
                                            batch_first = TRUE, enforce_sorted = TRUE)

# RNN
out <- rnn(packed)
```
```
Error in (function (self, other, alpha) :
  Expected a proper Tensor but got None (or an undefined Tensor in C++) for argument #0 'self'
```
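For comparison, the equivalent construction in PyTorch itself (a Python sketch, not part of the original report; assumes the Python `torch` package) runs with `enforce_sorted=True` without error, which suggests the problem lies in the R binding rather than in libtorch:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

# Same setup as the R snippet above: batch of 2 sequences
# of lengths 4 and 2, padded to max_len = 4
batch_size, input_size, hidden_size = 2, 3, 3
seq_len = torch.tensor([4, 2])
padded = torch.randn(batch_size, int(seq_len.max()), input_size)
padded[1, 2:, :] = 0  # zero out padding of the shorter sequence

rnn = nn.RNN(input_size, hidden_size, batch_first=True)
packed = pack_padded_sequence(padded, seq_len,
                              batch_first=True, enforce_sorted=True)
out, h_n = rnn(packed)  # no error here
```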
When the padded sequences are packed with `enforce_sorted=FALSE`, they are processed by the RNN without problems.
```r
# pack padded input
packed <- nn_utils_rnn_pack_padded_sequence(padded, torch_tensor(seq_len),
                                            batch_first = TRUE, enforce_sorted = FALSE)

# RNN
out <- rnn(packed)
out
```
To show that the output is correct, the first element of the RNN output needs to be unpacked.
```r
nn_utils_rnn_pad_packed_sequence(out[[1]], batch_first = TRUE, padding_value = 0)
```
```
[[1]]
torch_tensor
(1,.,.) =
  0.4470  0.9314 -0.0264
 -0.8938  0.1120  0.7772
 -0.9523  0.4127  0.0077
 -0.9207  0.6005  0.2798

(2,.,.) =
  0.1397  0.6618 -0.1076
 -0.2141  0.7277 -0.2181
  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000
[ CPUFloatType{2,4,3} ][ grad_fn = <IndexSelectBackward0> ]

[[2]]
torch_tensor
 4
 2
[ CPULongType{2} ]
```
As expected, the values of the hidden state are zero after the end of the second (shorter) sequence.
The error also occurs if `batch_first=FALSE`:
```r
# new padded tensor with dimensions (max_len, batch_size, input_size)
padded <- torch_randn(max(seq_len), batch_size, input_size)
padded[3:4, 2, ] <- 0  # padding for the second (shorter) sequence

# rnn module with batch_first = FALSE
rnn <- nn_rnn(input_size, hidden_size, batch_first = FALSE)

# pack padded input
packed <- nn_utils_rnn_pack_padded_sequence(padded, torch_tensor(seq_len),
                                            batch_first = FALSE, enforce_sorted = TRUE)
out <- rnn(packed)
```
```
> sessionInfo()
R version 4.3.1 (2023-06-16 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19045)

other attached packages:
[1] torch_0.11.0
```