Skip to content

Commit

Permalink
Reorder to allow type coercion for Unions
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Tom committed Nov 24, 2023
1 parent 4e0a368 commit 7ff7b3d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/pydantic_spark/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,18 @@ def get_type(value: dict) -> Tuple[str, dict]:
a = value.get("additionalProperties")
ft = value.get("coerce_type")
metadata = {}

if ft is not None:
return ft, metadata

if ao is not None:
if len(ao) == 2 and (ao[0].get("type") == "null" or ao[1].get("type") == "null"):
# this is an optional column. We will remove the null type
t = ao[0].get("type") if ao[0].get("type") != "null" else ao[1].get("type")
f = ao[0].get("format") if ao[0].get("type") != "null" else ao[1].get("format")
else:
NotImplementedError(f"Union type {ao} is not supported yet")

if "default" in value:
metadata["default"] = value.get("default")
if r is not None:
Expand All @@ -80,8 +85,6 @@ def get_type(value: dict) -> Tuple[str, dict]:
else:
spark_type = get_type_of_definition(r, schema)
classes_seen[class_name] = spark_type
elif ft is not None:
spark_type = ft
elif t == "array":
items = value.get("items")
tn, metadata = get_type(items)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_to_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,11 @@ def test_enum():
def test_coerce_type():
class TestCoerceType(SparkBase):
c1: int = Field(json_schema_extra={"coerce_type": CoerceType.integer})
c2: str | int = Field(json_schema_extra={"coerce_type": CoerceType.string})

result = TestCoerceType.spark_schema()
assert result["fields"][0]["type"] == "integer"
assert result["fields"][1]["type"] == "string"


class Nested2ModelCoerceType(SparkBase):
Expand Down

0 comments on commit 7ff7b3d

Please sign in to comment.