From 8e527243f4fc63efe3f4c0001673147ea6d48f85 Mon Sep 17 00:00:00 2001 From: "T.J. Alumbaugh" Date: Fri, 20 Sep 2024 12:45:12 -0700 Subject: [PATCH] Don't force dtype to torch.long in index_copy PiperOrigin-RevId: 676937992 --- ai_edge_torch/generative/layers/kv_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ai_edge_torch/generative/layers/kv_cache.py b/ai_edge_torch/generative/layers/kv_cache.py index 3479b61..32103e5 100644 --- a/ai_edge_torch/generative/layers/kv_cache.py +++ b/ai_edge_torch/generative/layers/kv_cache.py @@ -172,8 +172,8 @@ def _update_kv_base_impl( v_slice: torch.Tensor, ) -> KVCacheEntry: """Update the cache buffer without High Level Function Boundary annotation.""" - k = cache.k_cache.index_copy(1, input_pos.to(torch.long), k_slice) - v = cache.v_cache.index_copy(1, input_pos.to(torch.long), v_slice) + k = cache.k_cache.index_copy(1, input_pos, k_slice) + v = cache.v_cache.index_copy(1, input_pos, v_slice) updated_cache = KVCacheEntry(k, v) return updated_cache