From 4946699469a714d3a57eeaa74ff4aaefaa493f59 Mon Sep 17 00:00:00 2001 From: Piotr Synowiec Date: Wed, 16 Oct 2024 18:45:03 +0200 Subject: [PATCH] dev: filter_on_column, tests big refactor --- data_flow/data_flow.py | 38 +++++++++++++--- data_flow/lib/data_columns.py | 33 ++++++++++++++ tests/BaseTestCase.py | 80 +++++++++++++++++++++++++++++++++ tests/SequenceTestCase.py | 37 --------------- tests/test_data_flow_csv.py | 23 +++++----- tests/test_data_flow_feather.py | 23 +++++----- tests/test_data_flow_hdf.py | 23 +++++----- tests/test_data_flow_json.py | 23 +++++----- tests/test_data_flow_parquet.py | 23 +++++----- 9 files changed, 210 insertions(+), 93 deletions(-) delete mode 100644 tests/SequenceTestCase.py diff --git a/data_flow/data_flow.py b/data_flow/data_flow.py index ca8ad0f..592e353 100644 --- a/data_flow/data_flow.py +++ b/data_flow/data_flow.py @@ -1,13 +1,20 @@ import os import tempfile +from typing import Any import fireducks.pandas as fd import pandas as pd import polars as pl from pyarrow import feather -from data_flow.lib import FileType -from data_flow.lib.data_columns import data_get_columns, data_delete_columns, data_rename_columns, data_select_columns +from data_flow.lib import FileType, Operator +from data_flow.lib.data_columns import ( + data_get_columns, + data_delete_columns, + data_rename_columns, + data_select_columns, + data_filter_on_column, +) from data_flow.lib.data_from import ( from_csv_2_file, from_feather_2_file, @@ -188,7 +195,26 @@ def columns_select(self, columns: list): else: data_select_columns(tmp_filename=self.__filename, file_type=self.__file_type, columns=columns) - # def filter_on_column(self, column: str, value: Any, operator: Operator): - # if self.__in_memory: - # - # + def filter_on_column(self, column: str, value: Any, operator: Operator): + if self.__in_memory: + match operator: + case Operator.Eq: + self.__data = self.__data[self.__data[column] == value] + case Operator.Gte: + self.__data = self.__data[self.__data[column] >= value] + case Operator.Lte: + self.__data = self.__data[self.__data[column] <= value] + case Operator.Gt: + self.__data = self.__data[self.__data[column] > value] + case Operator.Lt: + self.__data = self.__data[self.__data[column] < value] + case Operator.Ne: + self.__data = self.__data[self.__data[column] != value] + else: + data_filter_on_column( + tmp_filename=self.__filename, + file_type=self.__file_type, + column=column, + value=value, + operator=operator, + ) diff --git a/data_flow/lib/data_columns.py b/data_flow/lib/data_columns.py index dfdfd5a..2d10f76 100644 --- a/data_flow/lib/data_columns.py +++ b/data_flow/lib/data_columns.py @@ -1,6 +1,9 @@ +from typing import Any + import fireducks.pandas as fd from data_flow.lib.FileType import FileType +from data_flow.lib.Operator import Operator def data_get_columns(tmp_filename: str, file_type: FileType) -> list: @@ -47,3 +50,33 @@ def data_select_columns(tmp_filename: str, file_type: FileType, columns: list) - data.to_feather(tmp_filename) case _: raise ValueError(f"File type not implemented: {file_type} !") + + +def data_filter_on_column(tmp_filename: str, file_type: FileType, column: str, value: Any, operator: Operator) -> None: + match file_type: + case FileType.parquet: + data = fd.read_parquet(tmp_filename) + case FileType.feather: + data = fd.read_feather(tmp_filename) + case _: + raise ValueError(f"File type not implemented: {file_type} !") + + match operator: + case Operator.Eq: + data = data[data[column] == value] + case Operator.Gte: + data = data[data[column] >= value] + case Operator.Lte: + data = data[data[column] <= value] + case Operator.Gt: + data = data[data[column] > value] + case Operator.Lt: + data = data[data[column] < value] + case Operator.Ne: + data = data[data[column] != value] + + match file_type: + case FileType.parquet: + data.to_parquet(tmp_filename) + case FileType.feather: + data.to_feather(tmp_filename) diff --git a/tests/BaseTestCase.py b/tests/BaseTestCase.py index be9cf19..e76cd59 100644 --- a/tests/BaseTestCase.py +++ b/tests/BaseTestCase.py @@ -1,8 +1,12 @@ import unittest +from typing import Callable from zipfile import ZipFile import pandas as pd +from data_flow import DataFlow +from data_flow.lib import Operator + class BaseTestCase(unittest.TestCase): def setUp(self): @@ -18,3 +22,79 @@ def setUp(self): def assertPandasEqual(self, df1: pd.DataFrame, df2: pd.DataFrame): self.assertTrue(df1.equals(df2), "Pandas DataFrames are not equal !") + + def all(self, function: Callable): + self._sequence(data=function()) + self._filter_Eq(data=function()) + self._filter_Gte(data=function()) + self._filter_Lte(data=function()) + self._filter_Gt(data=function()) + self._filter_Lt(data=function()) + self._filter_Ne(data=function()) + + # @count_assertions + def _sequence(self, data: DataFlow.DataFrame) -> None: + self.assertPandasEqual(data.to_pandas(), DataFlow().DataFrame().from_csv(self.CSV_FILE).to_pandas()) + polars = data.to_polars() + + self.assertEqual(10, len(data.columns())) + + data.columns_delete( + [ + "Industry_aggregation_NZSIOC", + "Industry_code_NZSIOC", + "Industry_name_NZSIOC", + "Industry_code_ANZSIC06", + "Variable_code", + "Variable_name", + "Variable_category", + ] + ) + + self.assertEqual(3, len(data.columns())) + self.assertListEqual(["Year", "Units", "Value"], data.columns()) + + data.columns_rename(columns_mapping={"Year": "_year_", "Units": "_units_"}) + self.assertListEqual(["_year_", "_units_", "Value"], data.columns()) + + data.columns_select(columns=["_year_"]) + self.assertListEqual(["_year_"], data.columns()) + + self.assertPandasEqual( + DataFlow().DataFrame().from_polars(polars).to_pandas(), + DataFlow().DataFrame().from_csv(self.CSV_FILE).to_pandas(), + ) + + def _filter_Eq(self, data: DataFlow.DataFrame) -> None: + data.filter_on_column(column="Year", operator=Operator.Eq, value=2018) + self.assertListEqual([2018], list(data.to_pandas().Year.unique())) + + def _filter_Gte(self, data: DataFlow.DataFrame) -> None: + data.filter_on_column(column="Year", operator=Operator.Gte, value=2018) + result = data.to_pandas().Year.unique().tolist() + result.sort() + self.assertListEqual([2018, 2019, 2020, 2021, 2022, 2023], result) + + def _filter_Lte(self, data: DataFlow.DataFrame) -> None: + data.filter_on_column(column="Year", operator=Operator.Lte, value=2018) + result = data.to_pandas().Year.unique().tolist() + result.sort() + self.assertListEqual([2013, 2014, 2015, 2016, 2017, 2018], result) + + def _filter_Gt(self, data: DataFlow.DataFrame) -> None: + data.filter_on_column(column="Year", operator=Operator.Gt, value=2018) + result = data.to_pandas().Year.unique().tolist() + result.sort() + self.assertListEqual([2019, 2020, 2021, 2022, 2023], result) + + def _filter_Lt(self, data: DataFlow.DataFrame) -> None: + data.filter_on_column(column="Year", operator=Operator.Lt, value=2018) + result = data.to_pandas().Year.unique().tolist() + result.sort() + self.assertListEqual([2013, 2014, 2015, 2016, 2017], result) + + def _filter_Ne(self, data: DataFlow.DataFrame) -> None: + data.filter_on_column(column="Year", operator=Operator.Ne, value=2018) + result = data.to_pandas().Year.unique().tolist() + result.sort() + self.assertListEqual([2013, 2014, 2015, 2016, 2017, 2019, 2020, 2021, 2022, 2023], result) diff --git a/tests/SequenceTestCase.py b/tests/SequenceTestCase.py deleted file mode 100644 index 992279a..0000000 --- a/tests/SequenceTestCase.py +++ /dev/null @@ -1,37 +0,0 @@ -from data_flow import DataFlow -from tests.BaseTestCase import BaseTestCase - - -class SequenceTestCase(BaseTestCase): - def _sequence(self, data: DataFlow.DataFrame) -> None: - self.assertPandasEqual(data.to_pandas(), DataFlow().DataFrame().from_csv(self.CSV_FILE).to_pandas()) - - polars = data.to_polars() - - self.assertEqual(10, len(data.columns())) - - data.columns_delete( - [ - "Industry_aggregation_NZSIOC", - "Industry_code_NZSIOC", - "Industry_name_NZSIOC", - "Industry_code_ANZSIC06", - "Variable_code", - "Variable_name", - "Variable_category", - ] - ) - - self.assertEqual(3, len(data.columns())) - self.assertListEqual(["Year", "Units", "Value"], data.columns()) - - data.columns_rename(columns_mapping={"Year": "_year_", "Units": "_units_"}) - self.assertListEqual(["_year_", "_units_", "Value"], data.columns()) - - data.columns_select(columns=["_year_"]) - self.assertListEqual(["_year_"], data.columns()) - - self.assertPandasEqual( - DataFlow().DataFrame().from_polars(polars).to_pandas(), - DataFlow().DataFrame().from_csv(self.CSV_FILE).to_pandas(), - ) diff --git a/tests/test_data_flow_csv.py b/tests/test_data_flow_csv.py index 9eb8631..f68d2f5 100644 --- a/tests/test_data_flow_csv.py +++ b/tests/test_data_flow_csv.py @@ -3,29 +3,32 @@ from data_flow import DataFlow from data_flow.lib import FileType from data_flow.lib.tools import delete_file -from tests.SequenceTestCase import SequenceTestCase +from tests.BaseTestCase import BaseTestCase -class DataFlowCSVTestCase(SequenceTestCase): +class DataFlowCSVTestCase(BaseTestCase): def setUp(self): super().setUp() delete_file(self.TEST_CSV_FILE) DataFlow().DataFrame().from_csv(self.CSV_FILE).to_csv(self.TEST_CSV_FILE) def test_memory(self): - df = DataFlow().DataFrame().from_csv(self.TEST_CSV_FILE) - - self._sequence(data=df) + self.all(self.__memory) def test_parquet(self): - df = DataFlow().DataFrame(in_memory=False).from_csv(self.TEST_CSV_FILE) - - self._sequence(data=df) + self.all(self.__parquet) def test_feather(self): - df = DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_csv(self.TEST_CSV_FILE) + self.all(self.__feather) + + def __memory(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame().from_csv(self.TEST_CSV_FILE) + + def __parquet(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame(in_memory=False).from_csv(self.TEST_CSV_FILE) - self._sequence(data=df) + def __feather(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_csv(self.TEST_CSV_FILE) if __name__ == "__main__": diff --git a/tests/test_data_flow_feather.py b/tests/test_data_flow_feather.py index 74951e6..ec69196 100644 --- a/tests/test_data_flow_feather.py +++ b/tests/test_data_flow_feather.py @@ -3,29 +3,32 @@ from data_flow import DataFlow from data_flow.lib import FileType from data_flow.lib.tools import delete_file -from tests.SequenceTestCase import SequenceTestCase +from tests.BaseTestCase import BaseTestCase -class DataFlowFeatherTestCase(SequenceTestCase): +class DataFlowFeatherTestCase(BaseTestCase): def setUp(self): super().setUp() delete_file(self.TEST_FEATHER_FILE) DataFlow().DataFrame().from_csv(self.CSV_FILE).to_feather(self.TEST_FEATHER_FILE) def test_memory(self): - df = DataFlow().DataFrame().from_feather(self.TEST_FEATHER_FILE) - - self._sequence(data=df) + self.all(self.__memory) def test_parquet(self): - df = DataFlow().DataFrame(in_memory=False).from_feather(self.TEST_FEATHER_FILE) - - self._sequence(data=df) + self.all(self.__parquet) def test_feather(self): - df = DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_feather(self.TEST_FEATHER_FILE) + self.all(self.__feather) + + def __memory(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame().from_feather(self.TEST_FEATHER_FILE) + + def __parquet(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame(in_memory=False).from_feather(self.TEST_FEATHER_FILE) - self._sequence(data=df) + def __feather(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_feather(self.TEST_FEATHER_FILE) if __name__ == "__main__": diff --git a/tests/test_data_flow_hdf.py b/tests/test_data_flow_hdf.py index dc1680e..5c26a54 100644 --- a/tests/test_data_flow_hdf.py +++ b/tests/test_data_flow_hdf.py @@ -3,29 +3,32 @@ from data_flow import DataFlow from data_flow.lib import FileType from data_flow.lib.tools import delete_file -from tests.SequenceTestCase import SequenceTestCase +from tests.BaseTestCase import BaseTestCase -class DataFlowHdfTestCase(SequenceTestCase): +class DataFlowHdfTestCase(BaseTestCase): def setUp(self): super().setUp() delete_file(self.TEST_HDF_FILE) DataFlow().DataFrame().from_csv(self.CSV_FILE).to_hdf(self.TEST_HDF_FILE) def test_memory(self): - df = DataFlow().DataFrame().from_hdf(self.TEST_HDF_FILE) - - self._sequence(data=df) + self.all(self.__memory) def test_parquet(self): - df = DataFlow().DataFrame(in_memory=False).from_hdf(self.TEST_HDF_FILE) - - self._sequence(data=df) + self.all(self.__parquet) def test_feather(self): - df = DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_hdf(self.TEST_HDF_FILE) + self.all(self.__feather) + + def __memory(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame().from_hdf(self.TEST_HDF_FILE) + + def __parquet(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame(in_memory=False).from_hdf(self.TEST_HDF_FILE) - self._sequence(data=df) + def __feather(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_hdf(self.TEST_HDF_FILE) if __name__ == "__main__": diff --git a/tests/test_data_flow_json.py b/tests/test_data_flow_json.py index fa04d2f..5c03f26 100644 --- a/tests/test_data_flow_json.py +++ b/tests/test_data_flow_json.py @@ -3,29 +3,32 @@ from data_flow import DataFlow from data_flow.lib import FileType from data_flow.lib.tools import delete_file -from tests.SequenceTestCase import SequenceTestCase +from tests.BaseTestCase import BaseTestCase -class DataFlowJsonTestCase(SequenceTestCase): +class DataFlowJsonTestCase(BaseTestCase): def setUp(self): super().setUp() delete_file(self.TEST_JSON_FILE) DataFlow().DataFrame().from_csv(self.CSV_FILE).to_json(self.TEST_JSON_FILE) def test_memory(self): - df = DataFlow().DataFrame().from_json(self.TEST_JSON_FILE) - - self._sequence(data=df) + self.all(self.__memory) def test_parquet(self): - df = DataFlow().DataFrame(in_memory=False).from_json(self.TEST_JSON_FILE) - - self._sequence(data=df) + self.all(self.__parquet) def test_feather(self): - df = DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_json(self.TEST_JSON_FILE) + self.all(self.__feather) + + def __memory(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame().from_json(self.TEST_JSON_FILE) + + def __parquet(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame(in_memory=False).from_json(self.TEST_JSON_FILE) - self._sequence(data=df) + def __feather(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_json(self.TEST_JSON_FILE) if __name__ == "__main__": diff --git a/tests/test_data_flow_parquet.py b/tests/test_data_flow_parquet.py index abc3008..c38c0f5 100644 --- a/tests/test_data_flow_parquet.py +++ b/tests/test_data_flow_parquet.py @@ -3,29 +3,32 @@ from data_flow import DataFlow from data_flow.lib import FileType from data_flow.lib.tools import delete_file -from tests.SequenceTestCase import SequenceTestCase +from tests.BaseTestCase import BaseTestCase -class DataFlowParquetTestCase(SequenceTestCase): +class DataFlowParquetTestCase(BaseTestCase): def setUp(self): super().setUp() delete_file(self.TEST_PARQUET_FILE) DataFlow().DataFrame().from_csv(self.CSV_FILE).to_parquet(self.TEST_PARQUET_FILE) def test_memory(self): - df = DataFlow().DataFrame().from_parquet(self.TEST_PARQUET_FILE) - - self._sequence(data=df) + self.all(self.__memory) def test_parquet(self): - df = DataFlow().DataFrame(in_memory=False).from_parquet(self.TEST_PARQUET_FILE) - - self._sequence(data=df) + self.all(self.__parquet) def test_feather(self): - df = DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_parquet(self.TEST_PARQUET_FILE) + self.all(self.__feather) + + def __memory(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame().from_parquet(self.TEST_PARQUET_FILE) + + def __parquet(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame(in_memory=False).from_parquet(self.TEST_PARQUET_FILE) - self._sequence(data=df) + def __feather(self) -> DataFlow.DataFrame: + return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_parquet(self.TEST_PARQUET_FILE) if __name__ == "__main__":