Feature attribution for darts model #1235

Open · bohdan-safoniuk opened this issue Jan 23, 2024 · 2 comments

@bohdan-safoniuk

❓ Questions and Help

Hi, I've been working on a FeatureExplainer for the probabilistic TiDE model and eventually came up with this solution:

import numpy as np
import torch
import torch.nn as nn
from typing import Optional, Sequence, Tuple, Union

from captum.attr import FeatureAblation
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper
from darts import TimeSeries
from darts.models import TiDEModel
from darts.utils.data import MixedCovariatesInferenceDataset
from torch.utils.data import DataLoader
from tqdm import tqdm


class PlainProbabilisticTiDEModel(nn.Module):
    def __init__(self, model: TiDEModel, num_samples: int):
        super().__init__()
        self.model = model.model
        self.likelihood = model.likelihood
        self.batch_size = model.batch_size
        self.num_samples = num_samples

    def _sample_tiling(self, input_data_tuple, batch_sample_size: int):
        # mirrors darts' internal PLForecastingModule._sample_tiling:
        # repeat each input tensor `batch_sample_size` times along dim 0
        tiled = []
        for tensor in input_data_tuple:
            if tensor is not None:
                tiled.append(
                    tensor.tile((batch_sample_size,) + (1,) * (tensor.ndim - 1))
                )
            else:
                tiled.append(None)
        return tuple(tiled)

    def forward(
        self,
        x_past: torch.Tensor,
        x_future: torch.Tensor,
        x_static: torch.Tensor
    ) -> torch.Tensor:
        num_series = x_past.shape[0]
        batch_sample_size = min(
            max(self.batch_size // num_series, 1), self.num_samples
        )

        batch_predictions, sample_count = [], 0
        while sample_count < self.num_samples:
            # make sure we don't produce too many samples
            if sample_count + batch_sample_size > self.num_samples:
                batch_sample_size = self.num_samples - sample_count

            # stack multiple copies of the tensors to produce probabilistic forecasts
            input_data_tuple_samples = self._sample_tiling(
                (x_past, x_future, x_static), batch_sample_size
            )

            # get predictions for 1 whole batch (can include predictions of multiple series
            # and for multiple samples if a probabilistic forecast is produced)
            output = self.model(input_data_tuple_samples)
            # (batch_size, n_timestamps, n_components, n_lh_params)
            output = self.likelihood.sample(output)

            # reshape from 3d tensor (num_series x batch_sample_size, ...)
            # into 4d tensor (batch_sample_size, num_series, ...), where dim 0 represents the samples
            out_shape = output.shape
            output = output.reshape(
                (
                    batch_sample_size,
                    num_series,
                )
                + out_shape[1:]
            )

            # save all predictions and update the `sample_count` variable
            batch_predictions.append(output)
            sample_count += batch_sample_size

        batch_predictions = torch.cat(batch_predictions, dim=0)
        return batch_predictions.median(0).values


class TideExplainer:
    
    def __init__(self, model: TiDEModel, num_samples: int):
        self.model = PlainProbabilisticTiDEModel(
            model=model,
            num_samples=num_samples
        )
        self._collate_fn = model._batch_collate_fn
        self.uses_static_covariates = model.uses_static_covariates

        model_wrapped = ModelInputWrapper(self.model)
        self.method = FeatureAblation(model_wrapped)

    def _process_input_batch(
        self, input_batch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        # mirrors darts' internal _process_input_batch for mixed-covariates
        # models: concatenate past target, past covariates and historic
        # future covariates along the component dimension
        (
            past_target,
            past_covariates,
            historic_future_covariates,
            future_covariates,
            static_covariates,
        ) = input_batch
        x_past = torch.cat(
            [
                t
                for t in (past_target, past_covariates, historic_future_covariates)
                if t is not None
            ],
            dim=2,
        )
        return x_past, future_covariates, static_covariates

    def explain(
        self,
        target_series: Union[TimeSeries, Sequence[TimeSeries]],
        past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
        future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
        n: int,
        input_chunk_length: int,
        output_chunk_length: int,
        batch_size: int = 1,
        perturbations_per_eval: int = 100,
        verbose: bool = True
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Create feature importances for given time series data.
        
        Parameters
        ----------
        target_series
            The target series that are to be predicted into the future.
        past_covariates
            Some past-observed covariates that are used for predictions.
        future_covariates
            Some future-known covariates that are used for predictions.
        n
            Forecast horizon: The number of time steps to predict after the end of the target series.
        input_chunk_length
            The length of the target series the model takes as input.
        output_chunk_length
            The length of the target series the model emits in output.
        batch_size
            Number of time series to process simultaneously.
        perturbations_per_eval
            Number of feature perturbations processed in a single forward pass
            (see https://captum.ai/api/feature_ablation.html).
        verbose
            Whether to show progress bar for data loader.
        
        Returns
        -------
        time_importances
            ``np.ndarray`` of shape `(batch_size, input_chunk_length)`,
            which describes the dependence of the forecast on each historical timestep.
        past_importances
            ``np.ndarray`` of shape `(batch_size, n_targets + n_past_features + n_future_features)`,
            which describes the dependence of the forecast on each past feature at historical timesteps.
        future_importances
            ``np.ndarray`` of shape `(batch_size, n_future_features)`,
            which describes the dependence of the forecast on each future feature at future timesteps.
        static_importances
            ``np.ndarray`` of shape `(batch_size, n_static_features)`,
            which describes the dependence of the forecast on each static feature.
        """
        inference_dataset = MixedCovariatesInferenceDataset(
            target_series=target_series,
            past_covariates=past_covariates,
            future_covariates=future_covariates,
            n=n,
            input_chunk_length=input_chunk_length,
            output_chunk_length=output_chunk_length,
            use_static_covariates=self.uses_static_covariates
        )
        loader = DataLoader(
            inference_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
            collate_fn=self._collate_fn,
        )
        
        time_importances, past_importances, future_importances, static_importances = [], [], [], []
        for batch in tqdm(loader, disable=not verbose):
            (
                x_past_target,
                x_past_covariates,
                x_historic_future_covariates,
                x_future_covariates,
                x_future_past_covariates,
                x_static_covariates,
                _
            ) = batch

            input_batch = self._process_input_batch(
                input_batch=(
                    x_past_target,
                    x_past_covariates,
                    x_historic_future_covariates,
                    x_future_covariates,
                    x_static_covariates
                )
            )

            # compute attributions of each input to the forecast
            # x_past_attrs: 
            # (batch_size * forecast_horizon, input_length, n_targets + n_past_features + n_future_features)
            # x_future_attrs: (batch_size * forecast_horizon, input_length, n_future_features)
            # x_static_attrs: (batch_size * forecast_horizon, n_targets, n_static_features)
            x_past_attrs, x_future_attrs, x_static_attrs = self.method.attribute(
                inputs=input_batch,
                perturbations_per_eval=perturbations_per_eval
            )

            # reshape attributes from (batch_size * forecast_horizon, ...) 
            # to (batch_size, num_timestamps, ...)
            x_past_attrs = x_past_attrs.reshape(batch_size, -1, *x_past_attrs.shape[1:])
            x_future_attrs = x_future_attrs.reshape(batch_size, -1, *x_future_attrs.shape[1:])
            x_static_attrs = x_static_attrs.reshape(batch_size, -1, *x_static_attrs.shape[1:])

            # calculate feature importances
            time_importance = x_past_attrs.sum(dim=3).sum(dim=1)  # (batch_size, input_length)
            past_importance = x_past_attrs.sum(dim=2).sum(dim=1)
            # (batch_size, n_targets + n_past_features + n_future_features)
            future_importance = x_future_attrs.sum(dim=2).sum(dim=1)  # (batch_size, n_future_features)
            static_importance = x_static_attrs.sum(dim=2).sum(dim=1)  # (batch_size, n_static_features)

            time_importances.append(time_importance)
            past_importances.append(past_importance)
            future_importances.append(future_importance)
            static_importances.append(static_importance)

        time_importances = torch.cat(time_importances, dim=0).cpu().numpy()
        past_importances = torch.cat(past_importances, dim=0).cpu().numpy()
        future_importances = torch.cat(future_importances, dim=0).cpu().numpy()
        static_importances = torch.cat(static_importances, dim=0).cpu().numpy()

        return time_importances, past_importances, future_importances, static_importances

Basically, it's a feature-ablation approach with additional darts-to-torch transformations.
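
For reference, I use it roughly like this (sketch only; `series`, `past_cov` and `future_cov` are placeholder objects and the hyperparameters are illustrative):

from darts.models import TiDEModel
from darts.utils.likelihood_models import GaussianLikelihood

# train a probabilistic TiDE model on placeholder series
model = TiDEModel(
    input_chunk_length=24,
    output_chunk_length=12,
    likelihood=GaussianLikelihood(),
)
model.fit(series, past_covariates=past_cov, future_covariates=future_cov)

# attribute the median forecast to the inputs
explainer = TideExplainer(model, num_samples=100)
time_imp, past_imp, future_imp, static_imp = explainer.explain(
    target_series=series,
    past_covariates=past_cov,
    future_covariates=future_cov,
    n=12,
    input_chunk_length=24,
    output_chunk_length=12,
)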

My question is: is there any way to apply a gradient-based attribution method to a model whose output has shape (batch_size, n_timestamps, n_targets)? If not, are perturbation/permutation methods the only option for forecasting models?
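
The only workaround I can think of is reducing the output to one scalar per example before attribution, e.g. (sketch; `forecaster` stands for any hypothetical module with that output shape):

from captum.attr import IntegratedGradients

def scalarized_forward(x_past, x_future, x_static):
    # aggregate (batch_size, n_timestamps, n_targets) to (batch_size,)
    # so that captum sees one scalar per example
    out = forecaster(x_past, x_future, x_static)
    return out.sum(dim=(1, 2))

ig = IntegratedGradients(scalarized_forward)
attributions = ig.attribute(inputs=(x_past, x_future, x_static))

Captum's `target` argument does accept tuple indices for outputs with more than two dimensions, but that only attributes a single (timestep, target) element per call, which gets expensive for long horizons.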

@BohdanBilonoh

Hi there 👋 I am also interested in this question.

@dimochak

+1
