Skip to content

Commit

Permalink
handle no target transforms in DistributedMLForecast.to_local (#388)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jmoralez authored Jul 22, 2024
1 parent 9bdbd11 commit 6271920
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
7 changes: 6 additions & 1 deletion mlforecast/distributed/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,8 @@ def possibly_concat_indices(collection):
return combined

def combine_target_tfms(by_partition):
if by_partition[0] is None:
return None
by_transform = [
[part[i] for part in by_partition] for i in range(len(by_partition[0]))
]
Expand Down Expand Up @@ -836,7 +838,10 @@ def combine_core_lag_tfms(by_partition):
statics = ufp.drop_index_if_pandas(statics)
for tfm in combined_core_lag_tfms.values():
tfm._core_tfm = tfm._core_tfm.take(sort_idxs)
combined_target_tfms = [tfm.take(sort_idxs) for tfm in combined_target_tfms]
if combined_target_tfms is not None:
combined_target_tfms = [
tfm.take(sort_idxs) for tfm in combined_target_tfms
]
old_data = data.copy()
old_indptr = indptr.copy()
indptr = np.append(0, sizes[sort_idxs]).cumsum()
Expand Down
7 changes: 6 additions & 1 deletion nbs/distributed.forecast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,8 @@
" return combined\n",
"\n",
" def combine_target_tfms(by_partition):\n",
" if by_partition[0] is None:\n",
" return None\n",
" by_transform = [\n",
" [part[i] for part in by_partition]\n",
" for i in range(len(by_partition[0]))\n",
Expand Down Expand Up @@ -879,7 +881,10 @@
" statics = ufp.drop_index_if_pandas(statics)\n",
" for tfm in combined_core_lag_tfms.values():\n",
" tfm._core_tfm = tfm._core_tfm.take(sort_idxs)\n",
" combined_target_tfms = [tfm.take(sort_idxs) for tfm in combined_target_tfms]\n",
" if combined_target_tfms is not None:\n",
" combined_target_tfms = [\n",
" tfm.take(sort_idxs) for tfm in combined_target_tfms\n",
" ]\n",
" old_data = data.copy()\n",
" old_indptr = indptr.copy()\n",
" indptr = np.append(0, sizes[sort_idxs]).cumsum()\n",
Expand Down
54 changes: 39 additions & 15 deletions nbs/docs/getting-started/quick_start_distributed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -413,48 +413,48 @@
" <th>0</th>\n",
" <td>id_00</td>\n",
" <td>2002-09-27 00:00:00</td>\n",
" <td>21.024446</td>\n",
" <td>21.710263</td>\n",
" <td>22.489947</td>\n",
" <td>21.679944</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>id_00</td>\n",
" <td>2002-09-28 00:00:00</td>\n",
" <td>84.190221</td>\n",
" <td>84.160383</td>\n",
" <td>81.806826</td>\n",
" <td>84.151205</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>id_00</td>\n",
" <td>2002-09-29 00:00:00</td>\n",
" <td>164.370398</td>\n",
" <td>163.325095</td>\n",
" <td>162.705641</td>\n",
" <td>164.024508</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>id_00</td>\n",
" <td>2002-09-30 00:00:00</td>\n",
" <td>246.09351</td>\n",
" <td>246.099914</td>\n",
" <td>246.990386</td>\n",
" <td>246.099977</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>id_00</td>\n",
" <td>2002-10-01 00:00:00</td>\n",
" <td>311.239076</td>\n",
" <td>314.455627</td>\n",
" <td>314.741463</td>\n",
" <td>315.261537</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" unique_id ds DaskXGBForecast DaskLGBMForecast\n",
"0 id_00 2002-09-27 00:00:00 21.024446 21.710263\n",
"1 id_00 2002-09-28 00:00:00 84.190221 84.160383\n",
"2 id_00 2002-09-29 00:00:00 164.370398 163.325095\n",
"3 id_00 2002-09-30 00:00:00 246.09351 246.099914\n",
"4 id_00 2002-10-01 00:00:00 311.239076 314.455627"
"0 id_00 2002-09-27 00:00:00 22.489947 21.679944\n",
"1 id_00 2002-09-28 00:00:00 81.806826 84.151205\n",
"2 id_00 2002-09-29 00:00:00 162.705641 164.024508\n",
"3 id_00 2002-09-30 00:00:00 246.990386 246.099977\n",
"4 id_00 2002-10-01 00:00:00 314.741463 315.261537"
]
},
"execution_count": null,
Expand Down Expand Up @@ -666,6 +666,30 @@
"pd.testing.assert_frame_equal(preds, local_preds, check_dtype=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84fa8337-f06c-4457-9587-48d47d2ac61a",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test to_local without target transforms\n",
"fcst_no_targ_tfms = DistributedMLForecast(\n",
" models=[DaskXGBForecast(n_estimators=5, random_state=0)],\n",
" freq='D', \n",
" lags=[1],\n",
" lag_transforms={1: [ExpandingMean()]},\n",
" date_features=['dayofweek'],\n",
")\n",
"fcst_no_targ_tfms.fit(\n",
" partitioned_series,\n",
" static_features=['static_0', 'static_1'],\n",
")\n",
"local_fcst = fcst_no_targ_tfms.to_local()\n",
"assert local_fcst.ts.target_transforms is None"
]
},
{
"cell_type": "markdown",
"id": "29841c02-b0bc-44cc-a8f3-da31b442584b",
Expand Down

0 comments on commit 6271920

Please sign in to comment.