Skip to content

Commit

Permalink
Allow specifying question and context columns in LLM dataset configs
Browse files Browse the repository at this point in the history
  • Loading branch information
whoseoyster committed Sep 25, 2023
1 parent a87a1d1 commit 8437057
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 1 deletion.
6 changes: 6 additions & 0 deletions openlayer/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ class LLMInputSchema(BaseDatasetSchema):
inputVariableNames = ma.fields.List(
ma.fields.Str(validate=COLUMN_NAME_VALIDATION_LIST), required=True
)
contextColumnName = ma.fields.Str(
validate=COLUMN_NAME_VALIDATION_LIST, allow_none=True, load_default=None
)
questionColumnName = ma.fields.Str(
validate=COLUMN_NAME_VALIDATION_LIST, allow_none=True, load_default=None
)


class TabularInputSchema(BaseDatasetSchema):
Expand Down
46 changes: 46 additions & 0 deletions openlayer/validators/dataset_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,22 @@ class LLInputValidator(BaseDatasetValidator):
"""

input_variable_names: Optional[List[str]] = None
context_column_name: Optional[str] = None
question_column_name: Optional[str] = None

def _validate_inputs(self):
    """Runs the LLM-specific input validations.

    Reads `inputVariableNames`, `contextColumnName` and
    `questionColumnName` from the dataset config, stores them on the
    validator, and runs the matching validation for each one that is set
    to a truthy value.
    """
    config = self.dataset_config

    # Cache the config values that the individual validations read.
    self.input_variable_names = config.get("inputVariableNames")
    self.context_column_name = config.get("contextColumnName")
    self.question_column_name = config.get("questionColumnName")

    # Each entry pairs a configured value with the check that depends on it;
    # the order matches the order in which failures should be reported.
    checks = (
        (self.input_variable_names, self._validate_input_variables),
        (self.context_column_name, self._validate_context),
        (self.question_column_name, self._validate_question),
    )
    for configured_value, run_check in checks:
        if configured_value:
            run_check()

def _validate_input_variables(self):
"""Validates the data in the input variables columns."""
Expand Down Expand Up @@ -234,6 +242,44 @@ def _validate_input_variables(self):
"`inputVariableNames` do not exceed the maximum character limit."
)

def _validate_context(self):
    """Validations on the context column.

    Checks, in order, that the column named by `contextColumnName` is
    present in the dataset, that it exposes the pandas `.str` accessor, and
    that it passes the character limit check. Only the first failing check
    is reported, by appending a message to `self.failed_validations`.
    """
    if self.context_column_name not in self.dataset_df.columns:
        self.failed_validations.append(
            f"The context column `{self.context_column_name}` specified as"
            " `contextColumnName` is not in the dataset."
        )
    # The `.str` accessor is used as the signal that the column holds strings.
    elif not hasattr(self.dataset_df[self.context_column_name], "str"):
        self.failed_validations.append(
            f"The context column `{self.context_column_name}` specified as"
            " `contextColumnName` is not a string column."
        )
    elif exceeds_character_limit(self.dataset_df, self.context_column_name):
        # Message previously said "ground truth column" (copy-paste slip)
        # and rendered a doubled space before the limit.
        self.failed_validations.append(
            f"The context column `{self.context_column_name}` specified as"
            " `contextColumnName` contains strings that exceed the"
            f" {constants.MAXIMUM_CHARACTER_LIMIT} character limit."
        )

def _validate_question(self):
    """Validations on the question column.

    Checks, in order, that the column named by `questionColumnName` is
    present in the dataset, that it exposes the pandas `.str` accessor, and
    that it passes the character limit check. Only the first failing check
    is reported, by appending a message to `self.failed_validations`.
    """
    if self.question_column_name not in self.dataset_df.columns:
        self.failed_validations.append(
            f"The question column `{self.question_column_name}` specified as"
            " `questionColumnName` is not in the dataset."
        )
    # The `.str` accessor is used as the signal that the column holds strings.
    elif not hasattr(self.dataset_df[self.question_column_name], "str"):
        self.failed_validations.append(
            f"The question column `{self.question_column_name}` specified as"
            " `questionColumnName` is not a string column."
        )
    elif exceeds_character_limit(self.dataset_df, self.question_column_name):
        # Message previously said "ground truth column" (copy-paste slip)
        # and rendered a doubled space before the limit.
        self.failed_validations.append(
            f"The question column `{self.question_column_name}` specified as"
            " `questionColumnName` contains strings that exceed the"
            f" {constants.MAXIMUM_CHARACTER_LIMIT} character limit."
        )

@staticmethod
def _input_variables_not_castable_to_str(
dataset_df: pd.DataFrame,
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ install_requires =
openai
pandas
pybars3
requests
requests_toolbelt
requests>=2.28.2
tabulate
Expand Down

0 comments on commit 8437057

Please sign in to comment.