Skip to content

Commit

Permalink
dev: add filter_on_column; big refactor of the test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
mysiar committed Oct 16, 2024
1 parent bb38493 commit 4946699
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 93 deletions.
38 changes: 32 additions & 6 deletions data_flow/data_flow.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import os
import tempfile
from typing import Any

import fireducks.pandas as fd
import pandas as pd
import polars as pl
from pyarrow import feather

from data_flow.lib import FileType
from data_flow.lib.data_columns import data_get_columns, data_delete_columns, data_rename_columns, data_select_columns
from data_flow.lib import FileType, Operator
from data_flow.lib.data_columns import (
data_get_columns,
data_delete_columns,
data_rename_columns,
data_select_columns,
data_filter_on_column,
)
from data_flow.lib.data_from import (
from_csv_2_file,
from_feather_2_file,
Expand Down Expand Up @@ -188,7 +195,26 @@ def columns_select(self, columns: list):
else:
data_select_columns(tmp_filename=self.__filename, file_type=self.__file_type, columns=columns)

# def filter_on_column(self, column: str, value: Any, operator: Operator):
# if self.__in_memory:
#
#
def filter_on_column(self, column: str, value: Any, operator: Operator):
if self.__in_memory:
match operator:
case Operator.Eq:
self.__data = self.__data[self.__data[column] == value]
case Operator.Gte:
self.__data = self.__data[self.__data[column] >= value]
case Operator.Lte:
self.__data = self.__data[self.__data[column] <= value]
case Operator.Gt:
self.__data = self.__data[self.__data[column] > value]
case Operator.Lt:
self.__data = self.__data[self.__data[column] < value]
case Operator.Ne:
self.__data = self.__data[self.__data[column] != value]
else:
data_filter_on_column(
tmp_filename=self.__filename,
file_type=self.__file_type,
column=column,
value=value,
operator=operator,
)
33 changes: 33 additions & 0 deletions data_flow/lib/data_columns.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any

import fireducks.pandas as fd

from data_flow.lib.FileType import FileType
from data_flow.lib.Operator import Operator


def data_get_columns(tmp_filename: str, file_type: FileType) -> list:
Expand Down Expand Up @@ -47,3 +50,33 @@ def data_select_columns(tmp_filename: str, file_type: FileType, columns: list) -
data.to_feather(tmp_filename)
case _:
raise ValueError(f"File type not implemented: {file_type} !")


def data_filter_on_column(tmp_filename: str, file_type: FileType, column: str, value: Any, operator: Operator) -> None:
match file_type:
case FileType.parquet:
data = fd.read_parquet(tmp_filename)
case FileType.feather:
data = fd.read_feather(tmp_filename)
case _:
raise ValueError(f"File type not implemented: {file_type} !")

match operator:
case Operator.Eq:
data = data[data[column] == value]
case Operator.Gte:
data = data[data[column] >= value]
case Operator.Lte:
data = data[data[column] <= value]
case Operator.Gt:
data = data[data[column] > value]
case Operator.Lt:
data = data[data[column] < value]
case Operator.Ne:
data = data[data[column] != value]

match file_type:
case FileType.parquet:
data.to_parquet(tmp_filename)
case FileType.feather:
data.to_feather(tmp_filename)
80 changes: 80 additions & 0 deletions tests/BaseTestCase.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import unittest
from typing import Callable
from zipfile import ZipFile

import pandas as pd

from data_flow import DataFlow
from data_flow.lib import Operator


class BaseTestCase(unittest.TestCase):
def setUp(self):
Expand All @@ -18,3 +22,79 @@ def setUp(self):

def assertPandasEqual(self, df1: pd.DataFrame, df2: pd.DataFrame):
    """Assert that the two pandas DataFrames are element-wise equal."""
    frames_match = df1.equals(df2)
    self.assertTrue(frames_match, "Pandas DataFrames are not equal !")

def all(self, function: Callable):
    """Run every check in the suite, each on a fresh DataFrame built by *function*.

    A new frame is constructed per check because each check mutates or
    filters the frame it receives.
    """
    checks = (
        self._sequence,
        self._filter_Eq,
        self._filter_Gte,
        self._filter_Lte,
        self._filter_Gt,
        self._filter_Lt,
        self._filter_Ne,
    )
    for check in checks:
        check(data=function())

# @count_assertions
def _sequence(self, data: DataFlow.DataFrame) -> None:
    """Run the shared column-manipulation scenario against *data*.

    Deletes, renames and selects columns in sequence, checking the column
    list after each step, and verifies a polars round-trip preserves the
    original content.
    """
    # The incoming frame must match a fresh load of the source CSV.
    self.assertPandasEqual(data.to_pandas(), DataFlow().DataFrame().from_csv(self.CSV_FILE).to_pandas())
    # Snapshot to polars BEFORE any mutation — used for the round-trip check below.
    polars = data.to_polars()

    self.assertEqual(10, len(data.columns()))

    data.columns_delete(
        [
            "Industry_aggregation_NZSIOC",
            "Industry_code_NZSIOC",
            "Industry_name_NZSIOC",
            "Industry_code_ANZSIC06",
            "Variable_code",
            "Variable_name",
            "Variable_category",
        ]
    )

    # 10 columns minus the 7 deleted above.
    self.assertEqual(3, len(data.columns()))
    self.assertListEqual(["Year", "Units", "Value"], data.columns())

    data.columns_rename(columns_mapping={"Year": "_year_", "Units": "_units_"})
    self.assertListEqual(["_year_", "_units_", "Value"], data.columns())

    data.columns_select(columns=["_year_"])
    self.assertListEqual(["_year_"], data.columns())

    # The pre-mutation polars snapshot must still equal the raw CSV,
    # proving the to_polars conversion did not alias the mutated frame.
    self.assertPandasEqual(
        DataFlow().DataFrame().from_polars(polars).to_pandas(),
        DataFlow().DataFrame().from_csv(self.CSV_FILE).to_pandas(),
    )

def _filter_Eq(self, data: DataFlow.DataFrame) -> None:
    """Eq must keep only the rows matching the requested year."""
    data.filter_on_column(column="Year", operator=Operator.Eq, value=2018)
    remaining_years = list(data.to_pandas().Year.unique())
    self.assertListEqual([2018], remaining_years)

def _filter_Gte(self, data: DataFlow.DataFrame) -> None:
    """Gte must keep 2018 and every later year."""
    data.filter_on_column(column="Year", operator=Operator.Gte, value=2018)
    remaining_years = sorted(data.to_pandas().Year.unique().tolist())
    self.assertListEqual([2018, 2019, 2020, 2021, 2022, 2023], remaining_years)

def _filter_Lte(self, data: DataFlow.DataFrame) -> None:
    """Lte must keep 2018 and every earlier year."""
    data.filter_on_column(column="Year", operator=Operator.Lte, value=2018)
    remaining_years = sorted(data.to_pandas().Year.unique().tolist())
    self.assertListEqual([2013, 2014, 2015, 2016, 2017, 2018], remaining_years)

def _filter_Gt(self, data: DataFlow.DataFrame) -> None:
    """Gt must keep strictly later years only."""
    data.filter_on_column(column="Year", operator=Operator.Gt, value=2018)
    remaining_years = sorted(data.to_pandas().Year.unique().tolist())
    self.assertListEqual([2019, 2020, 2021, 2022, 2023], remaining_years)

def _filter_Lt(self, data: DataFlow.DataFrame) -> None:
    """Lt must keep strictly earlier years only."""
    data.filter_on_column(column="Year", operator=Operator.Lt, value=2018)
    remaining_years = sorted(data.to_pandas().Year.unique().tolist())
    self.assertListEqual([2013, 2014, 2015, 2016, 2017], remaining_years)

def _filter_Ne(self, data: DataFlow.DataFrame) -> None:
    """Ne must drop exactly the matching year and keep all others."""
    data.filter_on_column(column="Year", operator=Operator.Ne, value=2018)
    remaining_years = sorted(data.to_pandas().Year.unique().tolist())
    self.assertListEqual([2013, 2014, 2015, 2016, 2017, 2019, 2020, 2021, 2022, 2023], remaining_years)
37 changes: 0 additions & 37 deletions tests/SequenceTestCase.py

This file was deleted.

23 changes: 13 additions & 10 deletions tests/test_data_flow_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,32 @@
from data_flow import DataFlow
from data_flow.lib import FileType
from data_flow.lib.tools import delete_file
from tests.SequenceTestCase import SequenceTestCase
from tests.BaseTestCase import BaseTestCase


class DataFlowCSVTestCase(SequenceTestCase):
class DataFlowCSVTestCase(BaseTestCase):
def setUp(self):
super().setUp()
delete_file(self.TEST_CSV_FILE)
DataFlow().DataFrame().from_csv(self.CSV_FILE).to_csv(self.TEST_CSV_FILE)

def test_memory(self):
df = DataFlow().DataFrame().from_csv(self.TEST_CSV_FILE)

self._sequence(data=df)
self.all(self.__memory)

def test_parquet(self):
df = DataFlow().DataFrame(in_memory=False).from_csv(self.TEST_CSV_FILE)

self._sequence(data=df)
self.all(self.__parquet)

def test_feather(self):
df = DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_csv(self.TEST_CSV_FILE)
self.all(self.__feather)

def __memory(self) -> DataFlow.DataFrame:
return DataFlow().DataFrame().from_csv(self.TEST_CSV_FILE)

def __parquet(self) -> DataFlow.DataFrame:
return DataFlow().DataFrame(in_memory=False).from_csv(self.TEST_CSV_FILE)

self._sequence(data=df)
def __feather(self) -> DataFlow.DataFrame:
return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_csv(self.TEST_CSV_FILE)


if __name__ == "__main__":
Expand Down
23 changes: 13 additions & 10 deletions tests/test_data_flow_feather.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,32 @@
from data_flow import DataFlow
from data_flow.lib import FileType
from data_flow.lib.tools import delete_file
from tests.SequenceTestCase import SequenceTestCase
from tests.BaseTestCase import BaseTestCase


class DataFlowFeatherTestCase(SequenceTestCase):
class DataFlowFeatherTestCase(BaseTestCase):
    """Runs the shared check suite against frames loaded from a feather source file."""

    def setUp(self):
        super().setUp()
        # Rebuild the feather fixture from the reference CSV before each test.
        delete_file(self.TEST_FEATHER_FILE)
        DataFlow().DataFrame().from_csv(self.CSV_FILE).to_feather(self.TEST_FEATHER_FILE)

    def test_memory(self):
        self.all(self.__memory)

    def test_parquet(self):
        self.all(self.__parquet)

    def test_feather(self):
        self.all(self.__feather)

    # Factory helpers: `all` calls these once per check to get a fresh frame.
    def __memory(self) -> DataFlow.DataFrame:
        return DataFlow().DataFrame().from_feather(self.TEST_FEATHER_FILE)

    def __parquet(self) -> DataFlow.DataFrame:
        return DataFlow().DataFrame(in_memory=False).from_feather(self.TEST_FEATHER_FILE)

    def __feather(self) -> DataFlow.DataFrame:
        return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_feather(self.TEST_FEATHER_FILE)


if __name__ == "__main__":
Expand Down
23 changes: 13 additions & 10 deletions tests/test_data_flow_hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,32 @@
from data_flow import DataFlow
from data_flow.lib import FileType
from data_flow.lib.tools import delete_file
from tests.SequenceTestCase import SequenceTestCase
from tests.BaseTestCase import BaseTestCase


class DataFlowHdfTestCase(SequenceTestCase):
class DataFlowHdfTestCase(BaseTestCase):
def setUp(self):
super().setUp()
delete_file(self.TEST_HDF_FILE)
DataFlow().DataFrame().from_csv(self.CSV_FILE).to_hdf(self.TEST_HDF_FILE)

def test_memory(self):
df = DataFlow().DataFrame().from_hdf(self.TEST_HDF_FILE)

self._sequence(data=df)
self.all(self.__memory)

def test_parquet(self):
df = DataFlow().DataFrame(in_memory=False).from_hdf(self.TEST_HDF_FILE)

self._sequence(data=df)
self.all(self.__parquet)

def test_feather(self):
df = DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_hdf(self.TEST_HDF_FILE)
self.all(self.__feather)

def __memory(self) -> DataFlow.DataFrame:
return DataFlow().DataFrame().from_hdf(self.TEST_HDF_FILE)

def __parquet(self) -> DataFlow.DataFrame:
return DataFlow().DataFrame(in_memory=False).from_hdf(self.TEST_HDF_FILE)

self._sequence(data=df)
def __feather(self) -> DataFlow.DataFrame:
return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_hdf(self.TEST_HDF_FILE)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 4946699

Please sign in to comment.