Keep the order of values of the --models flag in batch conversion.
This makes it possible to change the order in which models are converted.
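For example, --models=phi2,gemma now converts phi2 before gemma, while --models=gemma,phi2 converts gemma first. Previously the conversion order was fixed by the script (tinyllama, gemma, gemma2, ...) no matter how the flag was written.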

PiperOrigin-RevId: 681949266
ai-edge-bot authored and copybara-github committed Oct 3, 2024
1 parent bb459de commit 66c56f0
Showing 1 changed file with 134 additions and 122 deletions.
ai_edge_torch/generative/tools/batch_convert.py (134 additions, 122 deletions)
@@ -63,6 +63,18 @@
"The list of models to convert.",
)

_PREFILL_SEQ_LEN = flags.DEFINE_integer(
"prefill_seq_len",
1024,
"The maximum size of prefill input tensor.",
)

_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
"kv_cache_max_len",
1280,
"The maximum size of KV cache buffer, including both prefill and decode.",
)

_PRECISIONS = flags.DEFINE_list(
"precisions",
["q8", "f32"],
@@ -89,6 +101,7 @@ class ConversionConfig:
kv_cache_max_len: int
export_precision: Sequence[ExportPrecision]
model_builder: Callable[..., torch.nn.Module]
model_size: str

def print_config(self) -> None:
"""Prints the conversion config."""
@@ -98,141 +111,139 @@ def print_config(self) -> None:
logging.info("Prefill seq len: %s", self.prefill_seq_len)
logging.info("KV cache max len: %s", self.kv_cache_max_len)
logging.info("Export precision: %s", self.export_precision)
logging.info("Model size: %s", self.model_size)


def prepare_conversion_configs() -> Sequence[ConversionConfig]:
"""Prepares the conversion configs according to the flags."""

def get_conversion_config(
model_name: str,
input_checkpoint_subdir: str,
tflite_output_subdir: str,
model_builder: Callable[..., torch.nn.Module],
model_size: str,
) -> ConversionConfig:
"""Returns the conversion config for a model."""
export_precision = []
if "q8" in _PRECISIONS.value:
export_precision.append(ExportPrecision.INT8)
if "f32" in _PRECISIONS.value:
export_precision.append(ExportPrecision.FP32)

return ConversionConfig(
model_name=model_name,
input_checkpoint=os.path.join(
_CHECKPOINT_ROOT_PATH.value, input_checkpoint_subdir
),
tflite_output_path=os.path.join(_OUTPUT_DIR.value, tflite_output_subdir),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
export_precision=export_precision,
model_builder=model_builder,
model_size=model_size,
)


def prepare_conversion_configs() -> Sequence[ConversionConfig]:
"""Prepares the conversion configs according to the flags."""
conversion_configs = []
if "tinyllama" in _MODELS.value:
conversion_configs.append(
ConversionConfig(
model_name="tinyllama",
input_checkpoint=os.path.join(
_CHECKPOINT_ROOT_PATH.value, "tiny_llama"
),
tflite_output_path=os.path.join(_OUTPUT_DIR.value, "tiny_llama"),
prefill_seq_len=1024,
kv_cache_max_len=1280,
export_precision=export_precision,
model_builder=tiny_llama.build_model,
)
)
if "gemma" in _MODELS.value:
conversion_configs.append(
ConversionConfig(
model_name="gemma",
input_checkpoint=os.path.join(
_CHECKPOINT_ROOT_PATH.value, "gemma-2b"
),
tflite_output_path=os.path.join(_OUTPUT_DIR.value, "gemma"),
prefill_seq_len=1024,
kv_cache_max_len=1280,
export_precision=export_precision,
model_builder=gemma1.build_2b_model,
)
)
if "gemma2" in _MODELS.value:
conversion_configs.append(
ConversionConfig(
model_name="gemma2",
input_checkpoint=os.path.join(
_CHECKPOINT_ROOT_PATH.value, "gemma2-2b"
),
tflite_output_path=os.path.join(_OUTPUT_DIR.value, "gemma2"),
prefill_seq_len=1024,
kv_cache_max_len=1280,
export_precision=export_precision,
model_builder=gemma2.build_2b_model,
)
)
if "llama3.2" in _MODELS.value:
conversion_configs.append(
ConversionConfig(
model_name="llama3_2",
input_checkpoint=os.path.join(_CHECKPOINT_ROOT_PATH.value, "llama"),
tflite_output_path=os.path.join(_OUTPUT_DIR.value, "llama"),
prefill_seq_len=1024,
kv_cache_max_len=1280,
export_precision=export_precision,
model_builder=llama.build_3b_model,
)
)
if "phi2" in _MODELS.value:
conversion_configs.append(
ConversionConfig(
model_name="phi2",
input_checkpoint=os.path.join(_CHECKPOINT_ROOT_PATH.value, "phi2"),
tflite_output_path=os.path.join(_OUTPUT_DIR.value, "phi2"),
prefill_seq_len=1024,
kv_cache_max_len=1280,
export_precision=export_precision,
model_builder=phi2.build_model,
)
)
if "phi3.5" in _MODELS.value:
conversion_configs.append(
ConversionConfig(
model_name="phi3_5",
input_checkpoint=os.path.join(_CHECKPOINT_ROOT_PATH.value, "phi3"),
tflite_output_path=os.path.join(_OUTPUT_DIR.value, "phi3"),
prefill_seq_len=1024,
kv_cache_max_len=1280,
export_precision=export_precision,
model_builder=phi3.build_model,
)
)
if "openelm" in _MODELS.value:
conversion_configs.append(
ConversionConfig(
model_name="openelm",
input_checkpoint=os.path.join(
_CHECKPOINT_ROOT_PATH.value, "openelm"
),
tflite_output_path=os.path.join(_OUTPUT_DIR.value, "openelm"),
prefill_seq_len=1024,
kv_cache_max_len=1280,
export_precision=export_precision,
model_builder=openelm.build_model,
)
)
if "smollm" in _MODELS.value:
conversion_configs.append(
ConversionConfig(
model_name="smollm",
input_checkpoint=os.path.join(
_CHECKPOINT_ROOT_PATH.value, "smollm"
),
tflite_output_path=os.path.join(_OUTPUT_DIR.value, "smollm"),
prefill_seq_len=1024,
kv_cache_max_len=1280,
export_precision=export_precision,
model_builder=smollm.build_model,
)
)
if "qwen2.5" in _MODELS.value:
conversion_configs.append(
ConversionConfig(
model_name="qwen2_5",
input_checkpoint=os.path.join(_CHECKPOINT_ROOT_PATH.value, "qwen"),
tflite_output_path=os.path.join(_OUTPUT_DIR.value, "qwen"),
prefill_seq_len=1024,
kv_cache_max_len=1280,
export_precision=export_precision,
model_builder=qwen.build_3b_model,
)
)
for model in _MODELS.value:
if model == "gemma":
conversion_configs.append(
get_conversion_config(
model_name="gemma",
input_checkpoint_subdir="gemma-2b",
tflite_output_subdir="gemma",
model_builder=gemma1.build_2b_model,
model_size="2b",
)
)
elif model == "gemma2":
conversion_configs.append(
get_conversion_config(
model_name="gemma2",
input_checkpoint_subdir="gemma2-2b",
tflite_output_subdir="gemma2",
model_builder=gemma2.build_2b_model,
model_size="2b",
)
)
elif model == "llama3.2":
conversion_configs.append(
get_conversion_config(
model_name="llama3.2",
input_checkpoint_subdir="llama",
tflite_output_subdir="llama",
model_builder=llama.build_3b_model,
model_size="3b",
)
)
elif model == "openelm":
conversion_configs.append(
get_conversion_config(
model_name="openelm",
input_checkpoint_subdir="openelm",
tflite_output_subdir="openelm",
model_builder=openelm.build_model,
model_size="3b",
)
)
elif model == "phi2":
conversion_configs.append(
get_conversion_config(
model_name="phi2",
input_checkpoint_subdir="phi2",
tflite_output_subdir="phi2",
model_builder=phi2.build_model,
model_size="2.7b",
)
)
elif model == "phi3.5":
conversion_configs.append(
get_conversion_config(
model_name="phi3.5",
input_checkpoint_subdir="phi3",
tflite_output_subdir="phi3",
model_builder=phi3.build_model,
model_size="3.8b",
)
)
elif model == "qwen2.5":
conversion_configs.append(
get_conversion_config(
model_name="qwen2.5",
input_checkpoint_subdir="qwen",
tflite_output_subdir="qwen",
model_builder=qwen.build_3b_model,
model_size="3b",
)
)
elif model == "smollm":
conversion_configs.append(
get_conversion_config(
model_name="smollm",
input_checkpoint_subdir="smollm",
tflite_output_subdir="smollm",
model_builder=smollm.build_model,
model_size="135m",
)
)
elif model == "tinyllama":
conversion_configs.append(
get_conversion_config(
model_name="tinyllama",
input_checkpoint_subdir="tiny_llama",
tflite_output_subdir="tiny_llama",
model_builder=tiny_llama.build_model,
model_size="1.1b",
)
)
else:
raise ValueError(f"Unsupported model: {model}")
return conversion_configs
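
As a standalone illustration of why the rewritten loop preserves order (not part of the diff): absl's flags.DEFINE_list parses a comma-separated value into a Python list, and the loop above iterates that list directly, whereas the deleted code ran a fixed sequence of membership tests, so its output order never depended on the flag.

models = ["phi2", "gemma"]  # what _MODELS.value holds for --models=phi2,gemma

# New behavior: iteration order follows the flag.
print(models)  # ['phi2', 'gemma']

# Old behavior: the if-chain imposed its own order, regardless of the flag.
hardcoded_order = ["tinyllama", "gemma", "gemma2", "llama3.2", "phi2",
                   "phi3.5", "openelm", "smollm", "qwen2.5"]
print([m for m in hardcoded_order if m in models])  # ['gemma', 'phi2']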


def get_output_filename(
model_name: str,
model_size: str,
precision: ExportPrecision,
prefill_seq_len: int,
kv_cache_max_len: int,
@@ -244,7 +255,7 @@ def get_output_filename(
precision_str = "f32"
else:
raise ValueError(f"Unsupported precision: {precision}")
return f"{model_name}_{precision_str}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite"
return f"{model_name}_{model_size}_{precision_str}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite"


def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None:
@@ -260,6 +271,7 @@ def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None:
for precision in config.export_precision:
output_filename = get_output_filename(
config.model_name,
config.model_size,
precision,
config.prefill_seq_len,
config.kv_cache_max_len,
