diff --git a/src/dalle_mini/model/modeling.py b/src/dalle_mini/model/modeling.py index 40125a3f9..0dd83db63 100644 --- a/src/dalle_mini/model/modeling.py +++ b/src/dalle_mini/model/modeling.py @@ -1612,6 +1612,7 @@ def generate( condition_scale: Optional[float] = 1.0, input_ids_uncond: Optional[jnp.ndarray] = None, attention_mask_uncond: Optional[jnp.ndarray] = None, + model_kwargs_uncond: Optional[Dict[str, jnp.ndarray]] = None, **model_kwargs, ): """Edit: Allow super conditioning.""" @@ -1651,33 +1652,46 @@ def generate( params, {"attention_mask": attention_mask, **model_kwargs_input}, ) - if condition_scale != 1.0: - assert ( - input_ids_uncond is not None - ), "`input_ids_uncond` has to be defined for super conditioning." - assert ( - do_sample is True - ), "`do_sample` has to be True for super conditioning." - assert ( - num_beams == 1 - ), "`num_beams` has to be 1 for super conditioning." + else: + # model_kwargs was passed in, doesn't include named parameter yet + model_kwargs["attention_mask"] = attention_mask + if condition_scale != 1.0: + model_kwargs_uncond = model_kwargs_uncond or {} + if model_kwargs_uncond.get("encoder_outputs") is None: model_kwargs_uncond = ( self._prepare_encoder_decoder_kwargs_for_generation( input_ids_uncond, params, { "attention_mask": attention_mask_uncond, - **model_kwargs_input, + **model_kwargs_uncond, }, ) ) - else: - model_kwargs_uncond = None + # since the null prompt is usually constant, allow passing only one for the whole batch + model_kwargs_uncond = jax.tree_util.tree_map( + lambda x, y: jnp.broadcast_to(x, y.shape), + model_kwargs_uncond, + model_kwargs, + ) + else: + model_kwargs_uncond = None + # prepare decoder_input_ids for generation input_ids = ( jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id ) + if condition_scale != 1.0: + assert (input_ids_uncond is not None) or ( + model_kwargs_uncond is not None + and model_kwargs_uncond.get("encoder_outputs") is not None + ), '`input_ids_uncond` or `model_kwargs_uncond["encoder_outputs"]` has to be defined for super conditioning.' + assert ( + do_sample is True + ), "`do_sample` has to be True for super conditioning." + assert num_beams == 1, "`num_beams` has to be 1 for super conditioning." + if not do_sample and num_beams == 1: logits_processor = self._get_logits_processor( no_repeat_ngram_size,