From 66c56f0fcea877148d7829eff55934c57d7038ae Mon Sep 17 00:00:00 2001
From: Google AI Edge
Date: Thu, 3 Oct 2024 11:03:16 -0700
Subject: [PATCH] Keep the order of values of the --models flag in batch
 conversion.

This allows changing the order of conversions.

PiperOrigin-RevId: 681949266
---
 .../generative/tools/batch_convert.py | 256 +++++++++---------
 1 file changed, 134 insertions(+), 122 deletions(-)

diff --git a/ai_edge_torch/generative/tools/batch_convert.py b/ai_edge_torch/generative/tools/batch_convert.py
index aa8bd08..db15adc 100644
--- a/ai_edge_torch/generative/tools/batch_convert.py
+++ b/ai_edge_torch/generative/tools/batch_convert.py
@@ -63,6 +63,18 @@
     "The list of models to convert.",
 )
 
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    "prefill_seq_len",
+    1024,
+    "The maximum size of the prefill input tensor.",
+)
+
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    "kv_cache_max_len",
+    1280,
+    "The maximum size of the KV cache buffer, including both prefill and decode.",
+)
+
 _PRECISIONS = flags.DEFINE_list(
     "precisions",
     ["q8", "f32"],
@@ -89,6 +101,7 @@ class ConversionConfig:
   kv_cache_max_len: int
   export_precision: Sequence[ExportPrecision]
   model_builder: Callable[..., torch.nn.Module]
+  model_size: str
 
   def print_config(self) -> None:
     """Prints the conversion config."""
@@ -98,141 +111,139 @@ def print_config(self) -> None:
     logging.info("Prefill seq len: %s", self.prefill_seq_len)
     logging.info("KV cache max len: %s", self.kv_cache_max_len)
     logging.info("Export precision: %s", self.export_precision)
+    logging.info("Model size: %s", self.model_size)
 
 
-def prepare_conversion_configs() -> Sequence[ConversionConfig]:
-  """Prepares the conversion configs according to the flags."""
-
+def get_conversion_config(
+    model_name: str,
+    input_checkpoint_subdir: str,
+    tflite_output_subdir: str,
+    model_builder: Callable[..., torch.nn.Module],
+    model_size: str,
+) -> ConversionConfig:
+  """Returns the conversion config for a model."""
   export_precision = []
   if "q8" in _PRECISIONS.value:
     export_precision.append(ExportPrecision.INT8)
   if "f32" in _PRECISIONS.value:
     export_precision.append(ExportPrecision.FP32)
 
+  return ConversionConfig(
+      model_name=model_name,
+      input_checkpoint=os.path.join(
+          _CHECKPOINT_ROOT_PATH.value, input_checkpoint_subdir
+      ),
+      tflite_output_path=os.path.join(_OUTPUT_DIR.value, tflite_output_subdir),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
+      export_precision=export_precision,
+      model_builder=model_builder,
+      model_size=model_size,
+  )
+
+
+def prepare_conversion_configs() -> Sequence[ConversionConfig]:
+  """Prepares the conversion configs according to the flags."""
   conversion_configs = []
-  if "tinyllama" in _MODELS.value:
-    conversion_configs.append(
-        ConversionConfig(
-            model_name="tinyllama",
-            input_checkpoint=os.path.join(
-                _CHECKPOINT_ROOT_PATH.value, "tiny_llama"
-            ),
-            tflite_output_path=os.path.join(_OUTPUT_DIR.value, "tiny_llama"),
-            prefill_seq_len=1024,
-            kv_cache_max_len=1280,
-            export_precision=export_precision,
-            model_builder=tiny_llama.build_model,
-        )
-    )
-  if "gemma" in _MODELS.value:
-    conversion_configs.append(
-        ConversionConfig(
-            model_name="gemma",
-            input_checkpoint=os.path.join(
-                _CHECKPOINT_ROOT_PATH.value, "gemma-2b"
-            ),
-            tflite_output_path=os.path.join(_OUTPUT_DIR.value, "gemma"),
-            prefill_seq_len=1024,
-            kv_cache_max_len=1280,
-            export_precision=export_precision,
-            model_builder=gemma1.build_2b_model,
-        )
-    )
-  if "gemma2" in _MODELS.value:
-    conversion_configs.append(
-        ConversionConfig(
model_name="gemma2", - input_checkpoint=os.path.join( - _CHECKPOINT_ROOT_PATH.value, "gemma2-2b" - ), - tflite_output_path=os.path.join(_OUTPUT_DIR.value, "gemma2"), - prefill_seq_len=1024, - kv_cache_max_len=1280, - export_precision=export_precision, - model_builder=gemma2.build_2b_model, - ) - ) - if "llama3.2" in _MODELS.value: - conversion_configs.append( - ConversionConfig( - model_name="llama3_2", - input_checkpoint=os.path.join(_CHECKPOINT_ROOT_PATH.value, "llama"), - tflite_output_path=os.path.join(_OUTPUT_DIR.value, "llama"), - prefill_seq_len=1024, - kv_cache_max_len=1280, - export_precision=export_precision, - model_builder=llama.build_3b_model, - ) - ) - if "phi2" in _MODELS.value: - conversion_configs.append( - ConversionConfig( - model_name="phi2", - input_checkpoint=os.path.join(_CHECKPOINT_ROOT_PATH.value, "phi2"), - tflite_output_path=os.path.join(_OUTPUT_DIR.value, "phi2"), - prefill_seq_len=1024, - kv_cache_max_len=1280, - export_precision=export_precision, - model_builder=phi2.build_model, - ) - ) - if "phi3.5" in _MODELS.value: - conversion_configs.append( - ConversionConfig( - model_name="phi3_5", - input_checkpoint=os.path.join(_CHECKPOINT_ROOT_PATH.value, "phi3"), - tflite_output_path=os.path.join(_OUTPUT_DIR.value, "phi3"), - prefill_seq_len=1024, - kv_cache_max_len=1280, - export_precision=export_precision, - model_builder=phi3.build_model, - ) - ) - if "openelm" in _MODELS.value: - conversion_configs.append( - ConversionConfig( - model_name="openelm", - input_checkpoint=os.path.join( - _CHECKPOINT_ROOT_PATH.value, "openelm" - ), - tflite_output_path=os.path.join(_OUTPUT_DIR.value, "openelm"), - prefill_seq_len=1024, - kv_cache_max_len=1280, - export_precision=export_precision, - model_builder=openelm.build_model, - ) - ) - if "smollm" in _MODELS.value: - conversion_configs.append( - ConversionConfig( - model_name="smollm", - input_checkpoint=os.path.join( - _CHECKPOINT_ROOT_PATH.value, "smollm" - ), - tflite_output_path=os.path.join(_OUTPUT_DIR.value, "smollm"), - prefill_seq_len=1024, - kv_cache_max_len=1280, - export_precision=export_precision, - model_builder=smollm.build_model, - ) - ) - if "qwen2.5" in _MODELS.value: - conversion_configs.append( - ConversionConfig( - model_name="qwen2_5", - input_checkpoint=os.path.join(_CHECKPOINT_ROOT_PATH.value, "qwen"), - tflite_output_path=os.path.join(_OUTPUT_DIR.value, "qwen"), - prefill_seq_len=1024, - kv_cache_max_len=1280, - export_precision=export_precision, - model_builder=qwen.build_3b_model, - ) - ) + for model in _MODELS.value: + if model == "gemma": + conversion_configs.append( + get_conversion_config( + model_name="gemma", + input_checkpoint_subdir="gemma-2b", + tflite_output_subdir="gemma", + model_builder=gemma1.build_2b_model, + model_size="2b", + ) + ) + elif model == "gemma2": + conversion_configs.append( + get_conversion_config( + model_name="gemma2", + input_checkpoint_subdir="gemma2-2b", + tflite_output_subdir="gemma2", + model_builder=gemma2.build_2b_model, + model_size="2b", + ) + ) + elif model == "llama3.2": + conversion_configs.append( + get_conversion_config( + model_name="llama3.2", + input_checkpoint_subdir="llama", + tflite_output_subdir="llama", + model_builder=llama.build_3b_model, + model_size="3b", + ) + ) + elif model == "openelm": + conversion_configs.append( + get_conversion_config( + model_name="openelm", + input_checkpoint_subdir="openelm", + tflite_output_subdir="openelm", + model_builder=openelm.build_model, + model_size="3b", + ) + ) + elif model == "phi2": + 
+      conversion_configs.append(
+          get_conversion_config(
+              model_name="phi2",
+              input_checkpoint_subdir="phi2",
+              tflite_output_subdir="phi2",
+              model_builder=phi2.build_model,
+              model_size="2.7b",
+          )
+      )
+    elif model == "phi3.5":
+      conversion_configs.append(
+          get_conversion_config(
+              model_name="phi3.5",
+              input_checkpoint_subdir="phi3",
+              tflite_output_subdir="phi3",
+              model_builder=phi3.build_model,
+              model_size="3.8b",
+          )
+      )
+    elif model == "qwen2.5":
+      conversion_configs.append(
+          get_conversion_config(
+              model_name="qwen2.5",
+              input_checkpoint_subdir="qwen",
+              tflite_output_subdir="qwen",
+              model_builder=qwen.build_3b_model,
+              model_size="3b",
+          )
+      )
+    elif model == "smollm":
+      conversion_configs.append(
+          get_conversion_config(
+              model_name="smollm",
+              input_checkpoint_subdir="smollm",
+              tflite_output_subdir="smollm",
+              model_builder=smollm.build_model,
+              model_size="135m",
+          )
+      )
+    elif model == "tinyllama":
+      conversion_configs.append(
+          get_conversion_config(
+              model_name="tinyllama",
+              input_checkpoint_subdir="tiny_llama",
+              tflite_output_subdir="tiny_llama",
+              model_builder=tiny_llama.build_model,
+              model_size="1.1b",
+          )
+      )
+    else:
+      raise ValueError(f"Unsupported model: {model}")
   return conversion_configs
 
 
 def get_output_filename(
     model_name: str,
+    model_size: str,
     precision: ExportPrecision,
     prefill_seq_len: int,
     kv_cache_max_len: int,
@@ -244,7 +255,7 @@ def get_output_filename(
     precision_str = "f32"
   else:
     raise ValueError(f"Unsupported precision: {precision}")
-  return f"{model_name}_{precision_str}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite"
+  return f"{model_name}_{model_size}_{precision_str}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite"
 
 
 def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None:
@@ -260,6 +271,7 @@ def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None:
     for precision in config.export_precision:
       output_filename = get_output_filename(
           config.model_name,
+          config.model_size,
           precision,
           config.prefill_seq_len,
           config.kv_cache_max_len,
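
Note (reviewer sketch, not part of the patch): the snippet below mirrors the updated f-string in get_output_filename() to show how the new model_size component and the flag defaults above compose into output filenames. The helper name output_filename and the example values are illustrative only.

# Reviewer sketch only -- mirrors the updated f-string in get_output_filename().
# "q8" corresponds to ExportPrecision.INT8 and "f32" to ExportPrecision.FP32;
# the defaults match the new --prefill_seq_len and --kv_cache_max_len flags.
def output_filename(
    model_name: str,
    model_size: str,
    precision_str: str,
    prefill_seq_len: int = 1024,
    kv_cache_max_len: int = 1280,
) -> str:
  return (
      f"{model_name}_{model_size}_{precision_str}"
      f"_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite"
  )


# With --models=smollm,gemma2 the conversions now run in exactly that order,
# and the filenames gain the per-model size, e.g.:
print(output_filename("smollm", "135m", "q8"))  # smollm_135m_q8_seq1024_ekv1280.tflite
print(output_filename("gemma2", "2b", "f32"))   # gemma2_2b_f32_seq1024_ekv1280.tflite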