diff --git a/README.md b/README.md index fe0ccf1..9c1910d 100644 --- a/README.md +++ b/README.md @@ -231,6 +231,7 @@ env.str("EMAIL", validate=[Length(min=4), Email()]) By default, a validation error is raised immediately upon calling a parser method for an invalid environment variable. To defer validation and raise an exception with the combined error messages for all invalid variables, pass `eager=False` to `Env`. +Unvalidated variables may have any type, usually `str | None`. Call `env.seal()` after all variables have been parsed. ```python diff --git a/src/environs/__init__.py b/src/environs/__init__.py index b4aeb2e..ac221af 100644 --- a/src/environs/__init__.py +++ b/src/environs/__init__.py @@ -8,9 +8,23 @@ import re import typing from collections.abc import Mapping +from datetime import ( + date as _date, +) +from datetime import ( + datetime as _datetime, +) +from datetime import ( + time as _time, +) +from datetime import ( + timedelta as _timedelta, +) +from decimal import Decimal from enum import Enum from pathlib import Path from urllib.parse import ParseResult, urlparse +from uuid import UUID import marshmallow as ma from dotenv.main import _walk_to_root, load_dotenv @@ -27,7 +41,36 @@ Subcast = typing.Union[typing.Type, typing.Callable[..., _T], ma.fields.Field] FieldType = typing.Type[ma.fields.Field] FieldOrFactory = typing.Union[FieldType, FieldFactory] -ParserMethod = typing.Callable + + +_int = int +_bool = bool +_str = str +_float = float +_list = typing.List[typing.Any] +_dict = typing.Dict[str, typing.Any] + + +class _ParserMethod(typing.Generic[_T]): + """Duck typing, do not use""" + + @typing.overload # type: ignore[no-overload-impl] + def __call__( + self, + name: str, + default: None, + subcast: typing.Optional[Subcast] = None, + **kwargs, + ) -> typing.Optional[_T]: ... + + @typing.overload + def __call__( + self, + name: str, + default: typing.Union[ma.utils._Missing, _T] = ma.missing, + subcast: typing.Optional[Subcast] = None, + **kwargs, + ) -> _T: ... _EXPANDED_VAR_PATTERN = re.compile(r"(? ParserMethod: +): def method( self: "Env", name: str, @@ -140,7 +183,7 @@ def method( return method -def _func2method(func: typing.Callable, method_name: str) -> ParserMethod: +def _func2method(func: typing.Callable, method_name: str): def method( self: "Env", name: str, @@ -359,20 +402,20 @@ def _format_num(self, value) -> int: class Env: """An environment variable reader.""" - __call__: ParserMethod = _field2method(ma.fields.Field, "__call__") + __call__: _ParserMethod[str] = _field2method(ma.fields.Field, "__call__") - int = _field2method(ma.fields.Int, "int") - bool = _field2method(ma.fields.Bool, "bool") - str = _field2method(ma.fields.Str, "str") - float = _field2method(ma.fields.Float, "float") - decimal = _field2method(ma.fields.Decimal, "decimal") - list = _field2method( + int: _ParserMethod[_int] = _field2method(ma.fields.Int, "int") + bool: _ParserMethod[_bool] = _field2method(ma.fields.Bool, "bool") + str: _ParserMethod[_str] = _field2method(ma.fields.Str, "str") + float: _ParserMethod[_float] = _field2method(ma.fields.Float, "float") + decimal: _ParserMethod[Decimal] = _field2method(ma.fields.Decimal, "decimal") + list: _ParserMethod[_list] = _field2method( _make_list_field, "list", preprocess=_preprocess_list, preprocess_kwarg_names=("subcast", "delimiter"), ) - dict = _field2method( + dict: _ParserMethod[_dict] = _field2method( ma.fields.Dict, "dict", preprocess=_preprocess_dict, @@ -384,19 +427,29 @@ class Env: "delimiter", ), ) - json = _field2method(ma.fields.Field, "json", preprocess=_preprocess_json) - datetime = _field2method(ma.fields.DateTime, "datetime") - date = _field2method(ma.fields.Date, "date") - time = _field2method(ma.fields.Time, "time") - path = _field2method(PathField, "path") - log_level = _field2method(LogLevelField, "log_level") - timedelta = _field2method(ma.fields.TimeDelta, "timedelta") - uuid = _field2method(ma.fields.UUID, "uuid") - url = _field2method(URLField, "url") - enum = _func2method(_enum_parser, "enum") - dj_db_url = _func2method(_dj_db_url_parser, "dj_db_url") - dj_email_url = _func2method(_dj_email_url_parser, "dj_email_url") - dj_cache_url = _func2method(_dj_cache_url_parser, "dj_cache_url") + json: _ParserMethod[_dict] = _field2method( + ma.fields.Field, "json", preprocess=_preprocess_json + ) + datetime: _ParserMethod[_datetime] = _field2method(ma.fields.DateTime, "datetime") + date: _ParserMethod[_date] = _field2method(ma.fields.Date, "date") + time: _ParserMethod[_time] = _field2method(ma.fields.Time, "time") + path: _ParserMethod[Path] = _field2method(PathField, "path") + log_level: _ParserMethod[_int] = _field2method(LogLevelField, "log_level") + timedelta: _ParserMethod[_timedelta] = _field2method( + ma.fields.TimeDelta, "timedelta" + ) + uuid: _ParserMethod[UUID] = _field2method(ma.fields.UUID, "uuid") + url: _ParserMethod[_str] = _field2method(URLField, "url") + enum: _ParserMethod[Enum] = _func2method(_enum_parser, "enum") + dj_db_url: _ParserMethod[typing.Dict[_str, _str]] = _func2method( + _dj_db_url_parser, "dj_db_url" + ) + dj_email_url: _ParserMethod[typing.Dict[_str, _str]] = _func2method( + _dj_email_url_parser, "dj_email_url" + ) + dj_cache_url: _ParserMethod[typing.Dict[_str, _str]] = _func2method( + _dj_cache_url_parser, "dj_cache_url" + ) def __init__(self, *, eager: _BoolType = True, expand_vars: _BoolType = False): self.eager = eager @@ -406,7 +459,7 @@ def __init__(self, *, eager: _BoolType = True, expand_vars: _BoolType = False): self._values: typing.Dict[_StrType, typing.Any] = {} self._errors: ErrorMapping = collections.defaultdict(list) self._prefix: typing.Optional[_StrType] = None - self.__custom_parsers__: typing.Dict[_StrType, ParserMethod] = {} + self.__custom_parsers__: typing.Dict[_StrType, _ParserMethod] = {} def __repr__(self) -> _StrType: return f"<{self.__class__.__name__}(eager={self.eager}, expand_vars={self.expand_vars})>" # noqa: E501