Skip to content

Commit

Permalink
fix: series shorter than max_horizon (#459)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Nov 28, 2024
1 parent fdb7772 commit 4dabb31
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
7 changes: 4 additions & 3 deletions mlforecast/grouped_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,10 @@ def expand_target(self, max_horizon: int) -> np.ndarray:
)
for j in range(max_horizon):
for i in range(self.n_groups):
out[self.indptr[i] : self.indptr[i + 1] - j, j] = self.data[
self.indptr[i] + j : self.indptr[i + 1]
]
if self.indptr[i + 1] - self.indptr[i] > j:
out[self.indptr[i] : self.indptr[i + 1] - j, j] = self.data[
self.indptr[i] + j : self.indptr[i + 1]
]
return out

def take_from_groups(self, idx: Union[int, slice]) -> "GroupedArray":
Expand Down
9 changes: 5 additions & 4 deletions nbs/grouped_array.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,12 @@
" )\n",
" for j in range(max_horizon):\n",
" for i in range(self.n_groups):\n",
" out[self.indptr[i] : self.indptr[i + 1] - j, j] = self.data[\n",
" self.indptr[i] + j : self.indptr[i + 1]\n",
" ]\n",
" if self.indptr[i + 1] - self.indptr[i] > j:\n",
" out[self.indptr[i] : self.indptr[i + 1] - j, j] = self.data[\n",
" self.indptr[i] + j : self.indptr[i + 1]\n",
" ]\n",
" return out\n",
" \n",
"\n",
" def take_from_groups(self, idx: Union[int, slice]) -> 'GroupedArray':\n",
" \"\"\"Takes `idx` from each group in the array.\"\"\"\n",
" ranges = [\n",
Expand Down

0 comments on commit 4dabb31

Please sign in to comment.