Skip to content

Commit

Permalink
Add test to check that the external kv cache op is present in the model.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681923541
  • Loading branch information
hheydary authored and copybara-github committed Oct 3, 2024
1 parent 6c4d39b commit bb459de
Showing 1 changed file with 37 additions and 16 deletions.
53 changes: 37 additions & 16 deletions ai_edge_torch/generative/test/test_model_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,31 +43,40 @@ def setUp(self):
)
)

def _test_model_with_kv_cache(self, config, pytorch_model):
def _get_params(self, enable_hlfb: bool):
"""Returns a model, edge model and the kwargs to use for testing."""
config = toy_model_with_kv_cache.get_model_config()
config.enable_hlfb = enable_hlfb
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
[10], dtype=torch.int
)
kv = kv_cache.KVCache.from_model_config(config)
kwargs = {
"tokens": tokens,
"input_pos": input_pos,
"kv_cache": kv,
}

edge_model = ai_edge_torch.convert(
pytorch_model,
sample_kwargs={
"tokens": tokens,
"input_pos": input_pos,
"kv_cache": kv,
},
sample_kwargs=kwargs,
)
edge_model.set_interpreter_builder(
self._interpreter_builder(edge_model.tflite_model())
)
return pytorch_model, edge_model, kwargs

def _test_model_with_kv_cache(self, enable_hlfb: bool):
pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)

self.assertTrue(
test_utils.compare_tflite_torch(
edge_model,
pytorch_model,
tokens,
input_pos,
kv,
kwargs["tokens"],
kwargs["input_pos"],
kwargs["kv_cache"],
signature_name="serving_default",
atol=1e-5,
rtol=1e-5,
Expand All @@ -79,19 +88,31 @@ def _test_model_with_kv_cache(self, config, pytorch_model):
reason="tests with custom ops are not supported on oss",
)
def test_toy_model_with_kv_cache(self):
config = toy_model_with_kv_cache.get_model_config()
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
self._test_model_with_kv_cache(config, pytorch_model)
self._test_model_with_kv_cache(enable_hlfb=False)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
)
def test_toy_model_with_kv_cache_with_hlfb(self):
config = toy_model_with_kv_cache.get_model_config()
config.enable_hlfb = True
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
self._test_model_with_kv_cache(config, pytorch_model)
self._test_model_with_kv_cache(enable_hlfb=True)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
)
def test_toy_model_has_ekv_op(self):
"""Tests that the model has the external kv cache op."""
_, edge_model, _ = self._get_params(enable_hlfb=True)
interpreter_ = interpreter.InterpreterWithCustomOps(
custom_op_registerers=["GenAIOpsRegisterer"],
model_content=edge_model.tflite_model(),
experimental_default_delegate_latest_features=True,
)

# pylint: disable=protected-access
op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
self.assertIn("odml.update_external_kv_cache", op_names)

def _test_multisig_model(self, config, pytorch_model, atol, rtol):
# prefill
Expand Down

0 comments on commit bb459de

Please sign in to comment.