Skip to content

Commit

Permalink
Added plot_mean_convergence method
Browse files Browse the repository at this point in the history
  • Loading branch information
AnotherSamWilson committed Aug 2, 2024
1 parent 5ec0cf4 commit fcf447e
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 171 deletions.
259 changes: 88 additions & 171 deletions miceforest/imputed_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib.metadata
from io import BytesIO
from itertools import combinations
from typing import Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -111,9 +112,6 @@ def __init__(
self.initialized = False
self.imputed_variable_count = len(self.imputed_variables)
self.modeled_variable_count = len(self.modeled_variables)
# self.iterations = np.zeros(
# shape=(self.num_datasets, self.modeled_variable_count)
# ).astype(int)

# Create a multiindexed dataframe to store our imputation values
iv_multiindex = MultiIndex.from_product(
Expand All @@ -132,6 +130,9 @@ def __init__(
for dataset in datasets:
self.iteration_tab[variable, dataset] = 0

# Save the version of miceforest that was used to make this kernel
self.version = importlib.metadata.version("miceforest")

# Subsetting allows us to get to the imputation values:
def __getitem__(self, tup):
variable, iteration, dataset = tup
Expand Down Expand Up @@ -323,72 +324,6 @@ def complete_data(
if not inplace:
return impute_data

# def get_means(self, datasets, variables=None):
# """
# Return a dict containing the average imputation value
# for specified variables at each iteration.
# """
# num_vars = self._get_num_vars(variables)

# # For every variable, get the correlations between every dataset combination
# # at each iteration
# curr_iteration = self.iteration_count(datasets=datasets)
# if self.save_all_iterations:
# iter_range = list(range(curr_iteration + 1))
# else:
# iter_range = [curr_iteration]
# mean_dict = {
# ds: {
# var: {itr: np.mean(self[ds, var, itr]) for itr in iter_range}
# for var in num_vars
# }
# for ds in datasets
# }

# return mean_dict

# def plot_mean_convergence(self, datasets=None, variables=None, **adj_args):
# """
# Plots the average value of imputations over each iteration.

# Parameters
# ----------
# variables: None or list
# The variables to plot. Must be numeric.
# adj_args
# Passed to matplotlib.pyplot.subplots_adjust()

# """

# try:
# import matplotlib.pyplot as plt
# from matplotlib import gridspec
# except ImportError:
# raise ImportError("matplotlib must be installed to plot mean convergence")

# if self.iteration_count() < 2 or not self.save_all_iterations:
# raise ValueError("There is only one iteration.")

# if datasets is None:
# datasets = list(range(self.dataset_count()))
# else:
# datasets = _ensure_iterable(datasets)
# num_vars = self._get_num_vars(variables)
# mean_dict = self.get_means(datasets=datasets, variables=variables)
# plots, plotrows, plotcols = self._prep_multi_plot(num_vars)
# gs = gridspec.GridSpec(plotrows, plotcols)
# fig, ax = plt.subplots(plotrows, plotcols, squeeze=False)

# for v in range(plots):
# axr, axc = next(iter(gs[v].rowspan)), next(iter(gs[v].colspan))
# var = num_vars[v]
# for d in mean_dict.values():
# ax[axr, axc].plot(list(d[var].values()), color="black")
# ax[axr, axc].set_title(var)
# ax[axr, axc].set_xlabel("Iteration")
# ax[axr, axc].set_ylabel("mean")
# plt.subplots_adjust(**adj_args)

def plot_imputed_distributions(
self, variables: Optional[List[str]] = None, iteration: int = -1
):
Expand Down Expand Up @@ -467,105 +402,87 @@ def plot_imputed_distributions(

return fig

# def get_correlations(
# self, datasets: List[int], variables: Union[List[int], List[str]]
# ):
# """
# Return the correlations between datasets for
# the specified variables.

# Parameters
# ----------
# variables: list[str], list[int]
# The variables to return the correlations for.

# Returns
# -------
# dict
# The correlations at each iteration for the specified
# variables.

# """

# if self.dataset_count() < 3:
# raise ValueError(
# "Not enough datasets to calculate correlations between them"
# )
# curr_iteration = self.iteration_count()
# var_indx = self._get_var_ind_from_list(variables)

# # For every variable, get the correlations between every dataset combination
# # at each iteration
# correlation_dict = {}
# if self.save_all_iterations:
# iter_range = list(range(1, curr_iteration + 1))
# else:
# # Make this iterable for code tidyness
# iter_range = [curr_iteration]

# for var in var_indx:
# # Get a dict of variables and imputations for all datasets for this iteration
# iteration_level_imputations = {
# iteration: {ds: self[ds, var, iteration] for ds in datasets}
# for iteration in iter_range
# }

# combination_correlations = {
# iteration: [
# round(np.corrcoef(impcomb)[0, 1], 3)
# for impcomb in list(combinations(varimps.values(), 2))
# ]
# for iteration, varimps in iteration_level_imputations.items()
# }

# correlation_dict[var] = combination_correlations

# return correlation_dict

# def plot_correlations(self, datasets=None, variables=None, **adj_args):
# """
# Plot the correlations between datasets.
# See get_correlations() for more details.

# Parameters
# ----------
# datasets: None or list[int]
# The datasets to plot.
# variables: None,list
# The variables to plot.
# adj_args
# Additional arguments passed to plt.subplots_adjust()

# """

# try:
# import matplotlib.pyplot as plt
# from matplotlib import gridspec
# except ImportError:
# raise ImportError("matplotlib must be installed to plot importance")

# if self.dataset_count() < 4:
# raise ValueError("Not enough datasets to make box plot")
# if datasets is None:
# datasets = list(range(self.dataset_count()))
# else:
# datasets = _ensure_iterable(datasets)
# var_indx = self._get_var_ind_from_list(variables)
# num_vars = self._get_num_vars(var_indx)
# plots, plotrows, plotcols = self._prep_multi_plot(num_vars)
# correlation_dict = self.get_correlations(datasets=datasets, variables=num_vars)
# gs = gridspec.GridSpec(plotrows, plotcols)
# fig, ax = plt.subplots(plotrows, plotcols, squeeze=False)

# for v in range(plots):
# axr, axc = next(iter(gs[v].rowspan)), next(iter(gs[v].colspan))
# var = list(correlation_dict)[v]
# ax[axr, axc].boxplot(
# list(correlation_dict[var].values()),
# labels=range(len(correlation_dict[var])),
# )
# ax[axr, axc].set_title(self._get_var_name_from_scalar(var))
# ax[axr, axc].set_xlabel("Iteration")
# ax[axr, axc].set_ylabel("Correlations")
# ax[axr, axc].set_ylim([-1, 1])
# plt.subplots_adjust(**adj_args)
def plot_mean_convergence(
    self,
    variables: Optional[List[str]] = None,
):
    """
    Plot the mean and spread of the imputed values at each iteration.

    One line is drawn per dataset, showing the average imputation value of
    that dataset at each iteration. The error bars show, per iteration, the
    mean +/- one standard deviation of the imputation values within a
    dataset, averaged over all datasets. The points show the per-iteration
    average of the dataset means.

    Parameters
    ----------
    variables: Optional[List[str]], default=None
        The variables to plot. By default, all numeric, imputed variables
        are plotted. Requested variables that are not both numeric and
        imputed are silently skipped.

    Returns
    -------
    plotnine.ggplot
        The convergence plot, with one facet per variable.

    Raises
    ------
    ImportError
        If plotnine is not installed.
    """

    try:
        from plotnine import (
            aes,
            element_text,
            facet_wrap,
            geom_errorbar,
            geom_line,
            geom_point,
            ggplot,
            ggtitle,
            theme,
            theme_538,
            xlab,
            ylab,
        )
    except ImportError as err:
        raise ImportError(
            "plotnine must be installed to plot mean convergence."
        ) from err

    # Only variables that are both numeric and imputed can be plotted.
    numeric_columns = self.working_data.select_dtypes("number").columns.to_list()
    imputed = self.imputed_variables
    plottable = [v for v in numeric_columns if v in imputed]
    if variables is None:
        variables = plottable
    else:
        variables = [v for v in variables if v in plottable]

    frames = []
    for variable in variables:
        values = self.imputation_values[variable]

        # Melt once per column level so that every imputed value carries
        # both its iteration and its dataset label. Both melts walk the
        # columns in the same order, so the rows line up positionally.
        dat = values.melt(col_level="iteration")
        dat["dataset"] = values.melt(col_level="dataset")["dataset"]

        # Mean and standard deviation of the imputations within each
        # (dataset, iteration) cell.
        dat = (
            dat.groupby(["dataset", "iteration"])["value"]
            .agg(middle="mean", spread="std")
            .reset_index()
        )
        dat["upper"] = dat["middle"] + dat["spread"]
        dat["lower"] = dat["middle"] - dat["spread"]
        dat = dat.drop(columns="spread")

        # The error bars and points are shared by all datasets: replace
        # the per-dataset bounds with their average over datasets at each
        # iteration. "middle" stays per-dataset because it drives the lines.
        iteration_means = dat.groupby("iteration")[
            ["lower", "middle", "upper"]
        ].mean()
        dat["lower"] = dat["iteration"].map(iteration_means["lower"])
        dat["stdavg"] = dat["iteration"].map(iteration_means["middle"])
        dat["upper"] = dat["iteration"].map(iteration_means["upper"])
        dat["variable"] = variable
        frames.append(dat)

    # Concatenate once instead of growing a frame inside the loop. The
    # frames are reversed to keep the row order of the previous
    # implementation, which prepended each variable's rows.
    plot_dat = concat(frames[::-1], axis=0) if frames else DataFrame()

    fig = (
        ggplot(plot_dat, aes(x="iteration", y="middle", group="dataset"))
        + geom_line()
        + geom_errorbar(
            aes(x="iteration", ymin="lower", ymax="upper", group="dataset")
        )
        + geom_point(aes(x="iteration", y="stdavg"))
        + facet_wrap("variable", scales="free")
        + ggtitle("Mean Convergence Plot")
        + xlab("")
        + ylab("")
        # The complete theme must come first: adding a complete theme
        # replaces any previously added theme() customisations.
        + theme_538()
        + theme(
            plot_title=element_text(ha="left", size=20),
        )
    )

    return fig
1 change: 1 addition & 0 deletions tests/test_ImputationKernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def make_and_test_kernel(**kwargs):
# Test plotting
kernel.plot_imputed_distributions()
kernel.plot_feature_importance(dataset=0)
kernel.plot_mean_convergence()

return kernel

Expand Down

0 comments on commit fcf447e

Please sign in to comment.