diff --git a/ai_edge_torch/generative/layers/kv_cache.py b/ai_edge_torch/generative/layers/kv_cache.py index 3479b61..32103e5 100644 --- a/ai_edge_torch/generative/layers/kv_cache.py +++ b/ai_edge_torch/generative/layers/kv_cache.py @@ -172,8 +172,8 @@ def _update_kv_base_impl( v_slice: torch.Tensor, ) -> KVCacheEntry: """Update the cache buffer without High Level Function Boundary annotation.""" - k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice) - v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice) + k = cache.k_cache.index_copy(1, input_pos, k_slice) + v = cache.v_cache.index_copy(1, input_pos, v_slice) updated_cache = KVCacheEntry(k, v) return updated_cache