From d9ff57498a9da0d745e875fea6b24e912d12c7ce Mon Sep 17 00:00:00 2001 From: Herminio Vazquez Date: Sat, 29 Jun 2024 17:51:34 +0200 Subject: [PATCH] Feature custom check (#267) * added is custom to pyspark * Added is_custom validation for pyspark --- README.md | 1 + cuallee/__init__.py | 20 ++++++++- cuallee/pyspark_validation.py | 39 ++++++++++++++-- cuallee/report/__init__.py | 7 +-- pyproject.toml | 2 +- setup.cfg | 2 +- .../cuallee/test_performance_cuallee.py | 2 +- .../greatexpectations/test_performance_gx.py | 8 +++- .../performance/soda/test_performance_soda.py | 9 ++-- test/unit/bigquery/test_is_complete.py | 4 +- test/unit/bigquery/test_is_daily.py | 12 ++--- test/unit/bigquery/test_is_unique.py | 12 ++--- test/unit/bigquery/test_not_contained_in.py | 12 ++--- test/unit/daft/test_are_complete.py | 5 ++- test/unit/daft/test_are_unique.py | 5 ++- test/unit/daft/test_has_pattern.py | 5 ++- test/unit/daft/test_has_workflow.py | 5 ++- test/unit/daft/test_is_between.py | 5 ++- test/unit/daft/test_is_complete.py | 5 ++- test/unit/daft/test_is_contained_in.py | 5 ++- test/unit/daft/test_is_daily.py | 7 ++- test/unit/daft/test_is_empty.py | 5 ++- test/unit/daft/test_is_equal_than.py | 6 ++- .../daft/test_is_greater_or_equal_than.py | 6 ++- test/unit/daft/test_is_greater_than.py | 5 ++- test/unit/daft/test_is_in_billions.py | 5 ++- test/unit/daft/test_is_in_millions.py | 5 ++- .../test_is_inside_interquartile_range.py | 5 ++- test/unit/daft/test_is_less_or_equal_than.py | 6 ++- test/unit/daft/test_is_less_than.py | 6 ++- test/unit/daft/test_is_negative.py | 5 ++- test/unit/daft/test_is_on_friday.py | 6 ++- test/unit/daft/test_is_on_monday.py | 6 ++- test/unit/daft/test_is_on_saturday.py | 7 ++- test/unit/daft/test_is_on_schedule.py | 5 ++- test/unit/daft/test_is_on_sunday.py | 5 ++- test/unit/daft/test_is_on_thursday.py | 7 ++- test/unit/daft/test_is_on_tuesday.py | 5 ++- test/unit/daft/test_is_on_wednesday.py | 5 ++- test/unit/daft/test_is_on_weekday.py | 5 ++- test/unit/daft/test_is_on_weekend.py | 5 ++- test/unit/daft/test_is_positive.py | 5 ++- test/unit/daft/test_is_unique.py | 5 ++- test/unit/daft/test_not_contained_in.py | 5 ++- test/unit/daft/test_satisfies.py | 5 ++- test/unit/pandas_dataframe/test_is_empty.py | 1 - test/unit/pyspark_dataframe/test_is_custom.py | 45 +++++++++++++++++++ test/unit/pyspark_dataframe/test_is_empty.py | 6 +-- 48 files changed, 273 insertions(+), 81 deletions(-) create mode 100644 test/unit/pyspark_dataframe/test_is_custom.py diff --git a/README.md b/README.md index 826dcbff..ba841211 100644 --- a/README.md +++ b/README.md @@ -239,6 +239,7 @@ Check | Description | DataType `is_on_schedule` | For date fields confirms time windows i.e. `9:00 - 17:00` | _timestamp_ `is_daily` | Can verify daily continuity on date fields by default. `[2,3,4,5,6]` which represents `Mon-Fri` in PySpark. However new schedules can be used for custom date continuity | _date_ `has_workflow` | Adjacency matrix validation on `3-column` graph, based on `group`, `event`, `order` columns. | _agnostic_ +`is_custom` | User-defined custom `function` applied to dataframe for row-based validation. | _agnostic_ `satisfies` | An open `SQL expression` builder to construct custom checks | _agnostic_ `validate` | The ultimate transformation of a check with a `dataframe` input for validation | _agnostic_ diff --git a/cuallee/__init__.py b/cuallee/__init__.py index ee95a789..ad01ef77 100644 --- a/cuallee/__init__.py +++ b/cuallee/__init__.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta, timezone from types import ModuleType -from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union, Callable from toolz import compose, valfilter # type: ignore from toolz.curried import map as map_curried @@ -55,6 +55,8 @@ except (ModuleNotFoundError, ImportError): logger.debug("KO: BigQuery") +class CustomComputeException(Exception): + pass class CheckLevel(enum.Enum): """Level of verifications in cuallee""" @@ -1165,6 +1167,22 @@ def has_workflow( ) return self + def is_custom( + self, column: Union[str, List[str]], fn: Callable = None, pct: float = 1.0 + ): + """ + Uses a user-defined function that receives the to-be-validated dataframe + and uses the last column of the transformed dataframe to summarize the check + + Args: + column (str): Column(s) required for custom function + fn (Callable): A function that receives a dataframe as input and returns a dataframe with at least 1 column as result + pct (float): The threshold percentage required to pass + """ + + (Rule("is_custom", column, fn, CheckDataType.AGNOSTIC, pct) >> self._rule) + return self + def validate(self, dataframe: Any): """ Compute all rules in this check for specific data frame diff --git a/cuallee/pyspark_validation.py b/cuallee/pyspark_validation.py index 9b526532..ab022914 100644 --- a/cuallee/pyspark_validation.py +++ b/cuallee/pyspark_validation.py @@ -8,10 +8,10 @@ import pyspark.sql.types as T from pyspark.sql import Window as W from pyspark.sql import Column, DataFrame, Row -from toolz import first, valfilter # type: ignore +from toolz import first, valfilter, last # type: ignore import cuallee.utils as cuallee_utils -from cuallee import Check, ComputeEngine, Rule +from cuallee import Check, ComputeEngine, Rule, CustomComputeException import os @@ -587,6 +587,32 @@ def _execute(dataframe: DataFrame, key: str): return self.compute_instruction + def is_custom(self, rule: Rule): + """Validates dataframe by applying a custom function to the dataframe and resolving boolean values in the last column""" + + predicate = None + + def _execute(dataframe: DataFrame, key: str): + try: + assert isinstance(rule.value, Callable), "Please provide a Callable/Function for validation" + computed_frame = rule.value(dataframe) + assert isinstance(computed_frame, DataFrame), "Custom function does not return a PySpark DataFrame" + assert len(computed_frame.columns) >= 1, "Custom function should retun at least one column" + computed_column = last(computed_frame.columns) + return computed_frame.select( + F.sum(F.col(f"`{computed_column}`").cast("integer")).alias(key) + ) + + except Exception as err: + raise CustomComputeException(str(err)) + + + + self.compute_instruction = ComputeInstruction( + predicate, _execute, ComputeMethod.TRANSFORM + ) + + return self.compute_instruction def _field_type_filter( dataframe: DataFrame, @@ -769,6 +795,13 @@ def summary(check: Check, dataframe: DataFrame) -> DataFrame: # TODO: Check should have options for compute engine spark = SparkSession.builder.getOrCreate() + def _value(x): + """ Removes verbosity for Callable values""" + if isinstance(x, Callable): + return "f(x)" + else: + return str(x) + # Compute the expression computed_expressions = compute(check._rule) if (int(spark.version.replace(".", "")[:3]) < 330) or ( @@ -807,7 +840,7 @@ def summary(check: Check, dataframe: DataFrame) -> DataFrame: check.level.name, str(rule.column), str(rule.method), - str(rule.value), + _value(rule.value), int(check.rows), int(rule.violations), float(rule.pass_rate), diff --git a/cuallee/report/__init__.py b/cuallee/report/__init__.py index 3a865631..af4ec96b 100644 --- a/cuallee/report/__init__.py +++ b/cuallee/report/__init__.py @@ -1,11 +1,12 @@ from typing import List, Tuple from fpdf import FPDF -#from datetime import datetime, timezone + +# from datetime import datetime, timezone def pdf(data: List[Tuple[str]], name: str = "cuallee.pdf"): - #today = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S") - #style = FontFace(fill_color="#AAAAAA") + # today = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S") + # style = FontFace(fill_color="#AAAAAA") pdf = FPDF(orientation="landscape", format="A4") pdf.add_page() pdf.set_font("Helvetica", size=6) diff --git a/pyproject.toml b/pyproject.toml index bd6df983..83816670 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "cuallee" -version = "0.11.0" +version = "0.11.1" authors = [ { name="Herminio Vazquez", email="canimus@gmail.com"}, { name="Virginie Grosboillot", email="vestalisvirginis@gmail.com" } diff --git a/setup.cfg b/setup.cfg index 02f4c68f..ca26515c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [metadata] name = cuallee -version = 0.11.0 +version = 0.11.1 [options] packages = find: \ No newline at end of file diff --git a/test/performance/cuallee/test_performance_cuallee.py b/test/performance/cuallee/test_performance_cuallee.py index c4033e7f..7d475289 100755 --- a/test/performance/cuallee/test_performance_cuallee.py +++ b/test/performance/cuallee/test_performance_cuallee.py @@ -16,7 +16,7 @@ start = datetime.now() -check.validate(df).show(n=int(len(df.columns)*2), truncate=False) +check.validate(df).show(n=int(len(df.columns) * 2), truncate=False) end = datetime.now() elapsed = end - start print("START:", start) diff --git a/test/performance/greatexpectations/test_performance_gx.py b/test/performance/greatexpectations/test_performance_gx.py index d8371a8c..761965bf 100755 --- a/test/performance/greatexpectations/test_performance_gx.py +++ b/test/performance/greatexpectations/test_performance_gx.py @@ -12,8 +12,12 @@ start = datetime.now() -check_unique = [check.expect_column_values_to_be_unique(name).success for name in df.columns] -check_complete = [check.expect_column_values_to_not_be_null(name).success for name in df.columns] +check_unique = [ + check.expect_column_values_to_be_unique(name).success for name in df.columns +] +check_complete = [ + check.expect_column_values_to_not_be_null(name).success for name in df.columns +] end = datetime.now() print(check_unique + check_complete) diff --git a/test/performance/soda/test_performance_soda.py b/test/performance/soda/test_performance_soda.py index 1746c70c..26d2ac08 100755 --- a/test/performance/soda/test_performance_soda.py +++ b/test/performance/soda/test_performance_soda.py @@ -95,8 +95,11 @@ api_key_id: $soda_key api_key_secret: $soda_secret """ -scan.add_configuration_yaml_str(Template(config).substitute(soda_key=os.environ.get("SODA_KEY"), soda_secret=os.environ.get("SODA_SECRET"))) - +scan.add_configuration_yaml_str( + Template(config).substitute( + soda_key=os.environ.get("SODA_KEY"), soda_secret=os.environ.get("SODA_SECRET") + ) +) start = datetime.now() @@ -114,4 +117,4 @@ print("END:", end) print("ELAPSED:", elapsed) print("FRAMEWORK: soda") -spark.stop() \ No newline at end of file +spark.stop() diff --git a/test/unit/bigquery/test_is_complete.py b/test/unit/bigquery/test_is_complete.py index 5675bda2..e7ec9d30 100644 --- a/test/unit/bigquery/test_is_complete.py +++ b/test/unit/bigquery/test_is_complete.py @@ -12,7 +12,6 @@ def test_positive(): rs = check.validate(df) assert rs.status.str.match("PASS")[1] assert rs.violations[1] == 0 - def test_negative(): @@ -23,7 +22,6 @@ def test_negative(): assert rs.status.str.match("FAIL")[1] assert rs.violations[1] >= 1589 assert rs.pass_threshold[1] == 1.0 - # def test_parameters(): @@ -37,5 +35,5 @@ def test_coverage(): rs = check.validate(df) assert rs.status.str.match("PASS")[1] assert rs.violations[1] >= 1589 - #assert rs.pass_threshold[1] == 0.7 + # assert rs.pass_threshold[1] == 0.7 # assert rs.pass_rate[1] == 0.9999117752439066 # 207158222/207176656 diff --git a/test/unit/bigquery/test_is_daily.py b/test/unit/bigquery/test_is_daily.py index 1af070ef..6919cbed 100644 --- a/test/unit/bigquery/test_is_daily.py +++ b/test/unit/bigquery/test_is_daily.py @@ -11,7 +11,7 @@ def test_positive(): check = Check(CheckLevel.WARNING, "pytest") check.is_daily("trip_start_timestamp") rs = check.validate(df) - #assert rs.violations[1] > 1 + # assert rs.violations[1] > 1 def test_negative(): @@ -20,7 +20,7 @@ def test_negative(): check.is_daily("trip_end_timestamp") rs = check.validate(df) assert rs.status.str.match("FAIL")[1] - #assert rs.violations[1] >= 1 + # assert rs.violations[1] >= 1 # assert rs.pass_rate[1] <= 208914146 / 208943621 @@ -34,8 +34,8 @@ def test_parameters(rule_value): check = Check(CheckLevel.WARNING, "pytest") check.is_daily("trip_start_timestamp", rule_value) rs = check.validate(df) - #assert rs.status.str.match("FAIL")[1] - #assert rs.violations[1] > 0 + # assert rs.status.str.match("FAIL")[1] + # assert rs.violations[1] > 0 def test_coverage(): @@ -43,6 +43,6 @@ def test_coverage(): check = Check(CheckLevel.WARNING, "pytest") check.is_daily("trip_end_timestamp", pct=0.7) rs = check.validate(df) - #assert rs.status.str.match("PASS")[1] - #assert rs.pass_threshold[1] == 0.7 + # assert rs.status.str.match("PASS")[1] + # assert rs.pass_threshold[1] == 0.7 # assert rs.pass_rate[1] <= 208914146 / 208943621 diff --git a/test/unit/bigquery/test_is_unique.py b/test/unit/bigquery/test_is_unique.py index 19d3aff5..b6ffc343 100644 --- a/test/unit/bigquery/test_is_unique.py +++ b/test/unit/bigquery/test_is_unique.py @@ -12,7 +12,7 @@ def test_positive(): rs = check.validate(df) # assert rs.status.str.match("PASS")[1] # assert rs.violations[1] == 0 - #assert rs.pass_rate[1] == 1.0 + # assert rs.pass_rate[1] == 1.0 def test_negative(): @@ -20,9 +20,9 @@ def test_negative(): check = Check(CheckLevel.WARNING, "pytest") check.is_unique("taxi_id") rs = check.validate(df) - #assert rs.status.str.match("FAIL")[1] - #assert rs.violations[1] >= 102580503 - #assert rs.pass_threshold[1] == 1.0 + # assert rs.status.str.match("FAIL")[1] + # assert rs.violations[1] >= 102580503 + # assert rs.pass_threshold[1] == 1.0 # assert rs.pass_rate[1] == 9738 / 208943621 @@ -35,7 +35,7 @@ def test_coverage(): check = Check(CheckLevel.WARNING, "pytest") check.is_unique("taxi_id", 0.000007) rs = check.validate(df) - #assert rs.status.str.match("PASS")[1] - #assert rs.violations[1] >= 102580503 + # assert rs.status.str.match("PASS")[1] + # assert rs.violations[1] >= 102580503 # assert rs.pass_threshold[1] == 0.000007 # assert rs.pass_rate[1] == 9738 / 208943621 diff --git a/test/unit/bigquery/test_not_contained_in.py b/test/unit/bigquery/test_not_contained_in.py index 3159aa15..604bfd7b 100644 --- a/test/unit/bigquery/test_not_contained_in.py +++ b/test/unit/bigquery/test_not_contained_in.py @@ -11,9 +11,9 @@ def test_positive(): check = Check(CheckLevel.WARNING, "pytest") check.not_contained_in("payment_type", ["Dinero"]) rs = check.validate(df) - #assert rs.status.str.match("PASS")[1] - #assert rs.violations[1] == 0 - #assert rs.pass_rate[1] == 1.0 + # assert rs.status.str.match("PASS")[1] + # assert rs.violations[1] == 0 + # assert rs.pass_rate[1] == 1.0 def test_negative(): @@ -114,8 +114,8 @@ def test_parameters(column_name, rule_value): check = Check(CheckLevel.WARNING, "pytest") check.not_contained_in(column_name, rule_value) rs = check.validate(df) - #assert rs.status.str.match("FAIL")[1] - #assert rs.pass_rate[1] <= 1.0 + # assert rs.status.str.match("FAIL")[1] + # assert rs.pass_rate[1] <= 1.0 def test_coverage(): @@ -123,6 +123,6 @@ def test_coverage(): check = Check(CheckLevel.WARNING, "pytest") check.not_contained_in("payment_type", ("Dinero", "Metalico"), 0.7) rs = check.validate(df) - #assert rs.status.str.match("PASS")[1] + # assert rs.status.str.match("PASS")[1] # assert rs.violations[1] == 0 # assert rs.pass_threshold[1] == 0.7 diff --git a/test/unit/daft/test_are_complete.py b/test/unit/daft/test_are_complete.py index facf0a1d..87d3ff49 100644 --- a/test/unit/daft/test_are_complete.py +++ b/test/unit/daft/test_are_complete.py @@ -35,5 +35,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.75).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.75) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_are_unique.py b/test/unit/daft/test_are_unique.py index 4d07631c..2cd86507 100644 --- a/test/unit/daft/test_are_unique.py +++ b/test/unit/daft/test_are_unique.py @@ -36,5 +36,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.75).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.75) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_has_pattern.py b/test/unit/daft/test_has_pattern.py index a9bcc86d..d0c04950 100644 --- a/test/unit/daft/test_has_pattern.py +++ b/test/unit/daft/test_has_pattern.py @@ -49,5 +49,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.75).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.75) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_has_workflow.py b/test/unit/daft/test_has_workflow.py index 7079e35e..e7763809 100644 --- a/test/unit/daft/test_has_workflow.py +++ b/test/unit/daft/test_has_workflow.py @@ -33,5 +33,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 4/6).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 4 / 6) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_between.py b/test/unit/daft/test_is_between.py index 1a095380..3d7ba6f9 100644 --- a/test/unit/daft/test_is_between.py +++ b/test/unit/daft/test_is_between.py @@ -48,5 +48,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.55).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.55) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_complete.py b/test/unit/daft/test_is_complete.py index a82e7c72..6984ae1d 100644 --- a/test/unit/daft/test_is_complete.py +++ b/test/unit/daft/test_is_complete.py @@ -23,5 +23,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_contained_in.py b/test/unit/daft/test_is_contained_in.py index 5f20da78..3dd92e5a 100644 --- a/test/unit/daft/test_is_contained_in.py +++ b/test/unit/daft/test_is_contained_in.py @@ -35,5 +35,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_daily.py b/test/unit/daft/test_is_daily.py index 9fc46387..b2afb0a5 100644 --- a/test/unit/daft/test_is_daily.py +++ b/test/unit/daft/test_is_daily.py @@ -23,7 +23,7 @@ def test_negative(check: Check): ) df = daft.from_pandas(pd_df) result = check.validate(df) - #assert result.select(daft.col("status").str.match("FAIL")).to_pandas().status.all() + # assert result.select(daft.col("status").str.match("FAIL")).to_pandas().status.all() def test_coverage(check: Check): @@ -36,5 +36,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.60).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.60) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_empty.py b/test/unit/daft/test_is_empty.py index f55d127a..e3ce3f2f 100644 --- a/test/unit/daft/test_is_empty.py +++ b/test/unit/daft/test_is_empty.py @@ -23,5 +23,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_equal_than.py b/test/unit/daft/test_is_equal_than.py index 1dbbb542..f240a99c 100644 --- a/test/unit/daft/test_is_equal_than.py +++ b/test/unit/daft/test_is_equal_than.py @@ -34,6 +34,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.75).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.75) + .to_pandas() + .pass_rate.all() ) - diff --git a/test/unit/daft/test_is_greater_or_equal_than.py b/test/unit/daft/test_is_greater_or_equal_than.py index 817e4479..dbca30dd 100644 --- a/test/unit/daft/test_is_greater_or_equal_than.py +++ b/test/unit/daft/test_is_greater_or_equal_than.py @@ -34,6 +34,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.60).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.60) + .to_pandas() + .pass_rate.all() ) - diff --git a/test/unit/daft/test_is_greater_than.py b/test/unit/daft/test_is_greater_than.py index 19d92228..f1b289f0 100644 --- a/test/unit/daft/test_is_greater_than.py +++ b/test/unit/daft/test_is_greater_than.py @@ -34,5 +34,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_in_billions.py b/test/unit/daft/test_is_in_billions.py index 1c0bd234..29430edd 100644 --- a/test/unit/daft/test_is_in_billions.py +++ b/test/unit/daft/test_is_in_billions.py @@ -25,5 +25,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_in_millions.py b/test/unit/daft/test_is_in_millions.py index f5b5e273..671be561 100644 --- a/test/unit/daft/test_is_in_millions.py +++ b/test/unit/daft/test_is_in_millions.py @@ -25,5 +25,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_inside_interquartile_range.py b/test/unit/daft/test_is_inside_interquartile_range.py index cbcd0d38..87f24567 100644 --- a/test/unit/daft/test_is_inside_interquartile_range.py +++ b/test/unit/daft/test_is_inside_interquartile_range.py @@ -182,5 +182,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_less_or_equal_than.py b/test/unit/daft/test_is_less_or_equal_than.py index 8e7f14f6..ef08e4c8 100644 --- a/test/unit/daft/test_is_less_or_equal_than.py +++ b/test/unit/daft/test_is_less_or_equal_than.py @@ -34,6 +34,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.60).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.60) + .to_pandas() + .pass_rate.all() ) - diff --git a/test/unit/daft/test_is_less_than.py b/test/unit/daft/test_is_less_than.py index 02fc2ef2..d0f369ba 100644 --- a/test/unit/daft/test_is_less_than.py +++ b/test/unit/daft/test_is_less_than.py @@ -34,6 +34,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.60).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.60) + .to_pandas() + .pass_rate.all() ) - diff --git a/test/unit/daft/test_is_negative.py b/test/unit/daft/test_is_negative.py index e8e20e5c..e5c250fd 100644 --- a/test/unit/daft/test_is_negative.py +++ b/test/unit/daft/test_is_negative.py @@ -36,5 +36,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_on_friday.py b/test/unit/daft/test_is_on_friday.py index f3c777d7..a96ab448 100644 --- a/test/unit/daft/test_is_on_friday.py +++ b/test/unit/daft/test_is_on_friday.py @@ -39,6 +39,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 1 / 7).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 1 / 7) + .to_pandas() + .pass_rate.all() ) - diff --git a/test/unit/daft/test_is_on_monday.py b/test/unit/daft/test_is_on_monday.py index 94124871..d8906c30 100644 --- a/test/unit/daft/test_is_on_monday.py +++ b/test/unit/daft/test_is_on_monday.py @@ -39,6 +39,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 1/7).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 1 / 7) + .to_pandas() + .pass_rate.all() ) - diff --git a/test/unit/daft/test_is_on_saturday.py b/test/unit/daft/test_is_on_saturday.py index 398d5a42..a0fb75ab 100644 --- a/test/unit/daft/test_is_on_saturday.py +++ b/test/unit/daft/test_is_on_saturday.py @@ -39,5 +39,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 1 / 7).to_pandas().pass_rate.all() - ) \ No newline at end of file + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 1 / 7) + .to_pandas() + .pass_rate.all() + ) diff --git a/test/unit/daft/test_is_on_schedule.py b/test/unit/daft/test_is_on_schedule.py index 844a74c4..c1e28199 100644 --- a/test/unit/daft/test_is_on_schedule.py +++ b/test/unit/daft/test_is_on_schedule.py @@ -78,5 +78,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 7 / 8).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 7 / 8) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_on_sunday.py b/test/unit/daft/test_is_on_sunday.py index a076b31e..79da1188 100644 --- a/test/unit/daft/test_is_on_sunday.py +++ b/test/unit/daft/test_is_on_sunday.py @@ -39,5 +39,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 1 / 7).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 1 / 7) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_on_thursday.py b/test/unit/daft/test_is_on_thursday.py index d5138ddd..ea1cf84e 100644 --- a/test/unit/daft/test_is_on_thursday.py +++ b/test/unit/daft/test_is_on_thursday.py @@ -39,5 +39,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 1 / 7).to_pandas().pass_rate.all() - ) \ No newline at end of file + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 1 / 7) + .to_pandas() + .pass_rate.all() + ) diff --git a/test/unit/daft/test_is_on_tuesday.py b/test/unit/daft/test_is_on_tuesday.py index ab4b5bd1..11d3abe1 100644 --- a/test/unit/daft/test_is_on_tuesday.py +++ b/test/unit/daft/test_is_on_tuesday.py @@ -39,5 +39,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 1 / 7).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 1 / 7) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_on_wednesday.py b/test/unit/daft/test_is_on_wednesday.py index dd145f25..a79aa4ee 100644 --- a/test/unit/daft/test_is_on_wednesday.py +++ b/test/unit/daft/test_is_on_wednesday.py @@ -39,5 +39,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 1 / 7).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 1 / 7) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_on_weekday.py b/test/unit/daft/test_is_on_weekday.py index 52763a1b..b31e584e 100644 --- a/test/unit/daft/test_is_on_weekday.py +++ b/test/unit/daft/test_is_on_weekday.py @@ -39,5 +39,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 5 / 7).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 5 / 7) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_on_weekend.py b/test/unit/daft/test_is_on_weekend.py index bc90e74e..fe516a7e 100644 --- a/test/unit/daft/test_is_on_weekend.py +++ b/test/unit/daft/test_is_on_weekend.py @@ -39,5 +39,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 2 / 7).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 2 / 7) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_positive.py b/test/unit/daft/test_is_positive.py index 055843eb..e014e072 100644 --- a/test/unit/daft/test_is_positive.py +++ b/test/unit/daft/test_is_positive.py @@ -37,5 +37,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_is_unique.py b/test/unit/daft/test_is_unique.py index 40449dd4..996a8bb8 100644 --- a/test/unit/daft/test_is_unique.py +++ b/test/unit/daft/test_is_unique.py @@ -38,5 +38,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.75).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.75) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_not_contained_in.py b/test/unit/daft/test_not_contained_in.py index 1891a78a..99a494a4 100644 --- a/test/unit/daft/test_not_contained_in.py +++ b/test/unit/daft/test_not_contained_in.py @@ -36,5 +36,8 @@ def test_coverage(check: Check): col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/daft/test_satisfies.py b/test/unit/daft/test_satisfies.py index 69d50e58..d498d88d 100644 --- a/test/unit/daft/test_satisfies.py +++ b/test/unit/daft/test_satisfies.py @@ -26,5 +26,8 @@ def test_coverage(check: Check): assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all() col_pass_rate = daft.col("pass_rate") assert ( - result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all() + result.agg(col_pass_rate.max()) + .select(col_pass_rate == 0.50) + .to_pandas() + .pass_rate.all() ) diff --git a/test/unit/pandas_dataframe/test_is_empty.py b/test/unit/pandas_dataframe/test_is_empty.py index 2d14113c..115c53a1 100644 --- a/test/unit/pandas_dataframe/test_is_empty.py +++ b/test/unit/pandas_dataframe/test_is_empty.py @@ -20,4 +20,3 @@ def test_coverage(check: Check): df = pd.DataFrame({"id": [10, None], "id2": [None, "test"]}) result = check.validate(df) assert result.status.str.match("PASS").all() - diff --git a/test/unit/pyspark_dataframe/test_is_custom.py b/test/unit/pyspark_dataframe/test_is_custom.py new file mode 100644 index 00000000..db9a6b7d --- /dev/null +++ b/test/unit/pyspark_dataframe/test_is_custom.py @@ -0,0 +1,45 @@ +import pytest + +from cuallee import Check, CheckLevel, CustomComputeException +import pyspark.sql.functions as F + + +def test_positive(spark): + df = spark.range(10) + check = Check(CheckLevel.WARNING, "pytest") + check.is_custom("id", lambda x: x.withColumn("test", F.col("id") >= 0)) + rs = check.validate(df) + assert rs.first().status == "PASS" + assert rs.first().violations == 0 + assert rs.first().pass_threshold == 1.0 + + +def test_negative(spark): + df = spark.range(10) + check = Check(CheckLevel.WARNING, "pytest") + check.is_custom("id", lambda x: x.withColumn("test", F.col("id") >= 5)) + rs = check.validate(df) + assert rs.first().status == "FAIL" + assert rs.first().violations == 5 + assert rs.first().pass_threshold == 1.0 + + +def test_parameters(spark): + df = spark.range(10) + with pytest.raises( + CustomComputeException, match="Please provide a Callable/Function for validation" + ): + check = Check(CheckLevel.WARNING, "pytest") + check.is_custom("id", "wrong value") + check.validate(df) + + + + +def test_coverage(spark): + df = spark.range(10) + check = Check(CheckLevel.WARNING, "pytest") + check.is_custom("id", lambda x: x.withColumn("test", F.col("id") >= 5), 0.4) + rs = check.validate(df) + assert rs.first().status == "PASS" + assert rs.first().pass_threshold == 0.4 diff --git a/test/unit/pyspark_dataframe/test_is_empty.py b/test/unit/pyspark_dataframe/test_is_empty.py index 23c7abb4..ff2747a7 100644 --- a/test/unit/pyspark_dataframe/test_is_empty.py +++ b/test/unit/pyspark_dataframe/test_is_empty.py @@ -4,7 +4,9 @@ def test_positive(spark): - df = spark.createDataFrame([[None], [None], [None], [None], [None]], schema="id int") + df = spark.createDataFrame( + [[None], [None], [None], [None], [None]], schema="id int" + ) check = Check(CheckLevel.WARNING, "pytest") check.is_empty("id") rs = check.validate(df) @@ -29,7 +31,6 @@ def test_negative(spark, data, violation, pass_rate): assert rs.first().status == "FAIL" assert rs.first().violations == violation assert rs.first().pass_threshold == 1.0 - def test_parameters(): @@ -43,4 +44,3 @@ def test_coverage(spark): rs = check.validate(df) assert rs.first().status == "PASS" assert rs.first().pass_threshold == 0.1 -