From 1913fc1af151a6ce718484e324adfe3f6d17aca5 Mon Sep 17 00:00:00 2001 From: SSamDav Date: Wed, 9 Aug 2023 21:46:44 +0100 Subject: [PATCH] Fix error when calling _prepare_decoder_input_ids_for_generation --- src/ecco/lm.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/ecco/lm.py b/src/ecco/lm.py index 6f131de..05176f6 100644 --- a/src/ecco/lm.py +++ b/src/ecco/lm.py @@ -196,12 +196,21 @@ def generate(self, input_str: str, raise ValueError( "max_length set to {} while input token has more tokens ({}). Consider increasing max_length" \ .format(max_length, cur_len)) - + # Get decoder input ids if self.model_type == 'enc-dec': # FIXME: only done because causal LMs like GPT-2 have the _prepare_decoder_input_ids_for_generation method but do not use it assert len(input_ids.size()) == 2 # will break otherwise if version.parse(transformers.__version__) >= version.parse('4.13'): - decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids.shape[0], None, None) + + # following the code in https://github.com/huggingface/transformers/blob/d0c1aebea467af499331234e7b285a6bf91ea073/tests/generation/test_utils.py#L2099 + model_kwargs = self.model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {}) + decoder_input_ids, model_kwargs = self.model._prepare_decoder_input_ids_for_generation( + batch_size=input_ids.shape[0], + model_input_name=self.model.main_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=self.model.config.decoder_start_token_id, + bos_token_id=self.model.config.bos_token_id, + ) else: decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids, None, None) else: