Skip to content

Commit

Permalink
Add support feature display name in datasets. (#1161)
Browse files Browse the repository at this point in the history
* Add support feature display name in datasets.

* Fix possible None in features.

* Take proper display name for features.
Fix display name override for HuggingFaceToxicityModel.
  • Loading branch information
Liraim authored Jun 19, 2024
1 parent 5e70d53 commit a4b7442
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 3 deletions.
10 changes: 10 additions & 0 deletions src/evidently/calculation_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from evidently.base_metric import Metric
from evidently.base_metric import MetricResult
from evidently.calculation_engine.metric_implementation import MetricImplementation
from evidently.features.generated_features import GeneratedFeature
from evidently.utils.data_preprocessing import DataDefinition

TMetricImplementation = TypeVar("TMetricImplementation", bound=MetricImplementation)
TInputData = TypeVar("TInputData")
Expand Down Expand Up @@ -90,6 +92,14 @@ def get_metric_execution_iterator(self) -> List[Tuple[Metric, TMetricImplementat

return [(metric, self.get_metric_implementation(metric_to_calculations[metric])) for metric in self.metrics]

def form_datasets(
self,
data: Optional[TInputData],
features: Optional[Dict[tuple, GeneratedFeature]],
data_definition: DataDefinition,
):
raise NotImplementedError()


def _aggregate_metrics(agg, item):
agg[type(item)] = agg.get(type(item), []) + [item]
Expand Down
29 changes: 29 additions & 0 deletions src/evidently/calculation_engine/python_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import logging
from typing import Dict
from typing import Generic
from typing import Optional
from typing import TypeVar
Expand All @@ -12,6 +13,8 @@
from evidently.base_metric import Metric
from evidently.calculation_engine.engine import Engine
from evidently.calculation_engine.metric_implementation import MetricImplementation
from evidently.features.generated_features import GeneratedFeature
from evidently.utils.data_preprocessing import DataDefinition
from evidently.utils.data_preprocessing import create_data_definition


Expand Down Expand Up @@ -98,6 +101,32 @@ def calculate(self, context, data: PythonInputData):
return _Wrapper(self, metric)
return impl

def form_datasets(
self,
data: Optional[PythonInputData],
features: Optional[Dict[tuple, GeneratedFeature]],
data_definition: DataDefinition,
):
if data is None:
return None, None
if features is not None:
rename = {x.feature_name().name: x.feature_name().display_name for x in features.values()}
else:
rename = {}
current = data.current_data
if data.current_additional_features is not None:
current = data.current_data.join(data.current_additional_features)

current = current.rename(columns=rename)
reference = data.reference_data
if data.reference_data is not None and data.reference_additional_features is not None:
reference = data.reference_data.join(data.reference_additional_features)

if reference is not None:
reference = reference.rename(columns=rename)

return reference, current


class PythonMetricImplementation(Generic[TMetric], MetricImplementation):
def __init__(self, engine: PythonEngine, metric: TMetric):
Expand Down
2 changes: 1 addition & 1 deletion src/evidently/descriptors/hf_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def feature(self, column_name: str) -> GeneratedFeature:
model_str = "" if self.model is None else f"({self.model})"
return HuggingFaceToxicityFeature(
column_name=column_name,
display_name=f"Hugging Face Toxicity {model_str} for {column_name}",
display_name=self.display_name or f"Hugging Face Toxicity {model_str} for {column_name}",
model=self.model,
toxic_label=self.toxic_label,
)
2 changes: 1 addition & 1 deletion src/evidently/features/generated_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_parameters(self) -> Optional[tuple]:
hash(params)
except TypeError:
logging.warning(f"unhashable params for {type(self)}. Fallback to unique.")
return None
return (self.feature_id,)
return params


Expand Down
6 changes: 5 additions & 1 deletion src/evidently/suite/base_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def get_data_definition(
)
return self.data_definition

def get_datasets(self):
return self.engine.form_datasets(self.data, self.features, self.data_definition)


class ContextPayload(BaseModel):
metrics: List[Metric]
Expand Down Expand Up @@ -525,4 +528,5 @@ def to_snapshot(self):
return self._get_snapshot()

def datasets(self):
return self._inner_suite.context.data.get_datasets()
datasets = self._inner_suite.context.get_datasets()
return datasets

0 comments on commit a4b7442

Please sign in to comment.