Skip to content

Commit

Permalink
Infra: apply sourcery (#1892)
Browse files Browse the repository at this point in the history
* chore: update deps

* chore: applied sourcery

* chore: fix tests

* chore: fix remaining issues

* chore: handle flaky test
  • Loading branch information
Goldziher authored Jun 28, 2023
1 parent 8492653 commit e5f2b64
Show file tree
Hide file tree
Showing 110 changed files with 544 additions and 631 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ repos:
- id: slotscheck
exclude: "test_*|docs"
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.4.0"
rev: "v1.4.1"
hooks:
- id: mypy
exclude: "test_apps|tools|docs|tests/examples|tests/docker_service_fixtures"
Expand Down Expand Up @@ -119,7 +119,7 @@ repos:
prometheus_client,
]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.315
rev: v1.1.316
hooks:
- id: pyright
exclude: "test_apps|tools|docs|_openapi|tests/examples|tests/docker_service_fixtures"
Expand Down
18 changes: 9 additions & 9 deletions .sourcery.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
ignore:
- .git
- venv
- .venv
- env
- .env
- .tox
- node_modules
- vendor
- .tox/
- .venv/
- dist/
- docs/_build/
- docs/_static/
- node_modules/
- vendor/
- venv/

rule_settings:
enable: [default]
disable: []
disable: [dont-import-test-modules]
rule_types:
- refactoring
- suggestion
Expand Down
7 changes: 1 addition & 6 deletions docs/examples/contrib/jwt/using_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,7 @@ async def retrieve_user_handler(token: Token, connection: "ASGIConnection[Any, A
@post("/login")
async def login_handler(data: User) -> Response[User]:
MOCK_DB[str(data.id)] = data
response = jwt_auth.login(identifier=str(data.id), response_body=data)

# you can do whatever you want to update the response instance here
# e.g. response.set_cookie(...)

return response
return jwt_auth.login(identifier=str(data.id), response_body=data)


# We also have some other routes, for example:
Expand Down
7 changes: 1 addition & 6 deletions docs/examples/contrib/jwt/using_jwt_cookie_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,7 @@ async def retrieve_user_handler(token: "Token", connection: "ASGIConnection[Any,
@post("/login")
async def login_handler(data: "User") -> "Response[User]":
MOCK_DB[str(data.id)] = data
response = jwt_cookie_auth.login(identifier=str(data.id), response_body=data)

# you can do whatever you want to update the response instance here
# e.g. response.set_cookie(...)

return response
return jwt_cookie_auth.login(identifier=str(data.id), response_body=data)


# We also have some other routes, for example:
Expand Down
16 changes: 2 additions & 14 deletions docs/examples/contrib/jwt/using_oauth2_password_bearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,14 @@ async def retrieve_user_handler(token: "Token", connection: "ASGIConnection[Any,
@post("/login")
async def login_handler(request: "Request[Any, Any, Any]", data: "User") -> "Response[OAuth2Login]":
MOCK_DB[str(data.id)] = data
# if we do not define a response body, the login process will return a standard OAuth2 login response. Note the `Response[OAuth2Login]` return type.
response = oauth2_auth.login(identifier=str(data.id))

# you can do whatever you want to update the response instance here
# e.g. response.set_cookie(...)

return response
return oauth2_auth.login(identifier=str(data.id))


@post("/login_custom")
async def login_custom_response_handler(data: "User") -> "Response[User]":
MOCK_DB[str(data.id)] = data

# If you'd like to define a custom response body, use the `response_body` parameter. Note the `Response[User]` return type.
response = oauth2_auth.login(identifier=str(data.id), response_body=data)

# you can do whatever you want to update the response instance here
# e.g. response.set_cookie(...)

return response
return oauth2_auth.login(identifier=str(data.id), response_body=data)


# We also have some other routes, for example:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
result = await session.execute(query)
try:
return result.scalar_one()
except NoResultFound:
raise NotFoundException(detail=f"TODO {todo_name!r} not found")
except NoResultFound as e:
raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e


async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
Expand All @@ -77,11 +77,11 @@ async def add_item(data: TodoType, state: State) -> TodoType:
try:
async with session.begin():
session.add(new_todo)
except IntegrityError:
except IntegrityError as e:
raise ClientException(
status_code=HTTP_409_CONFLICT,
detail=f"TODO {new_todo.title!r} already exists",
)
) from e

return serialize_todo(new_todo)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncS
raise ClientException(
status_code=HTTP_409_CONFLICT,
detail=str(exc),
)
) from exc


async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
query = select(TodoItem).where(TodoItem.title == todo_name)
result = await session.execute(query)
try:
return result.scalar_one()
except NoResultFound:
raise NotFoundException(detail=f"TODO {todo_name!r} not found")
except NoResultFound as e:
raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e


async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ async def provide_transaction(db_session: AsyncSession) -> AsyncGenerator[AsyncS
raise ClientException(
status_code=HTTP_409_CONFLICT,
detail=str(exc),
)
) from exc


async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
query = select(TodoItem).where(TodoItem.title == todo_name)
result = await session.execute(query)
try:
return result.scalar_one()
except NoResultFound:
raise NotFoundException(detail=f"TODO {todo_name!r} not found")
except NoResultFound as e:
raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e


async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None
raise ClientException(
status_code=HTTP_409_CONFLICT,
detail=str(exc),
)
) from exc


async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
query = select(TodoItem).where(TodoItem.title == todo_name)
result = await session.execute(query)
try:
return result.scalar_one()
except NoResultFound:
raise NotFoundException(detail=f"TODO {todo_name!r} not found")
except NoResultFound as e:
raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e


async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def provide_transaction(state: State) -> AsyncGenerator[AsyncSession, None
raise ClientException(
status_code=HTTP_409_CONFLICT,
detail=str(exc),
)
) from exc


def serialize_todo(todo: TodoItem) -> TodoType:
Expand All @@ -63,8 +63,8 @@ async def get_todo_by_title(todo_name, session: AsyncSession) -> TodoItem:
result = await session.execute(query)
try:
return result.scalar_one()
except NoResultFound:
raise NotFoundException(detail=f"TODO {todo_name!r} not found")
except NoResultFound as e:
raise NotFoundException(detail=f"TODO {todo_name!r} not found") from e


async def get_todo_list(done: Optional[bool], session: AsyncSession) -> List[TodoItem]:
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/parameters/header_and_cookie_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def get_user(
token: Annotated[str, Parameter(header="X-API-KEY")],
cookie: Annotated[str, Parameter(cookie="my-cookie-param")],
) -> User:
if not (token == VALID_TOKEN and cookie == VALID_COOKIE_VALUE):
if token != VALID_TOKEN or cookie != VALID_COOKIE_VALUE:
raise NotAuthorizedException
return User.parse_obj(USER_DB[user_id])

Expand Down
8 changes: 1 addition & 7 deletions docs/examples/security/using_session_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,7 @@ class UserLoginPayload(BaseModel):
async def retrieve_user_handler(
session: Dict[str, Any], connection: "ASGIConnection[Any, Any, Any, Any]"
) -> Optional[User]:
# we retrieve the user instance based on session data

user_id = session.get("user_id")
if user_id:
return MOCK_DB.get(user_id)

return None
return MOCK_DB.get(user_id) if (user_id := session.get("user_id")) else None


@post("/login")
Expand Down
3 changes: 1 addition & 2 deletions litestar/_asgi/routing_trie/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,9 @@ def add_route_to_trie(
"""
current_node = root_node

is_mount = hasattr(route, "route_handler") and getattr(route.route_handler, "is_mount", False) # pyright: ignore
has_path_parameters = bool(route.path_parameters)

if is_mount: # pyright: ignore
if (route_handler := getattr(route, "route_handler", None)) and getattr(route_handler, "is_mount", False):
current_node = add_mount_route(
current_node=current_node,
mount_routes=mount_routes,
Expand Down
4 changes: 3 additions & 1 deletion litestar/_asgi/routing_trie/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def parse_path_to_route(
remaining_path = path[match.end() :]
# since we allow regular handlers under static paths, we must validate that the request does not match
# any such handler.
if not mount_node.children or not any(sub_route in path for sub_route in mount_node.children): # type: ignore
if not mount_node.children or all(
sub_route not in path for sub_route in mount_node.children # type: ignore
):
asgi_app, handler = parse_node_handlers(node=mount_node, method=method)
remaining_path = remaining_path or "/"
if not mount_node.is_static:
Expand Down
3 changes: 1 addition & 2 deletions litestar/_asgi/routing_trie/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def validate_node(node: RouteTrieNode) -> None:
node.is_mount
and node.children
and any(
v
for v in chain.from_iterable(
chain.from_iterable(
list(child.path_parameters.values())
if isinstance(child.path_parameters, dict)
else child.path_parameters
Expand Down
15 changes: 9 additions & 6 deletions litestar/_kwargs/kwargs_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _get_param_definitions(
layered_parameter = layered_parameters[field_name]
field = signature_field if signature_field.is_parameter_field else layered_parameter
default_value = (
signature_field.default_value if not signature_field.is_empty else layered_parameter.default_value
layered_parameter.default_value if signature_field.is_empty else signature_field.default_value
)

param_definitions.add(
Expand Down Expand Up @@ -309,10 +309,10 @@ def create_for_signature_model(
data_signature_field = signature_fields.get("data")

media_type: RequestEncodingType | str | None = None
if data_signature_field and isinstance(data_signature_field.kwarg_model, BodyKwarg):
media_type = data_signature_field.kwarg_model.media_type

if data_signature_field:
if isinstance(data_signature_field.kwarg_model, BodyKwarg):
media_type = data_signature_field.kwarg_model.media_type

if media_type in (RequestEncodingType.MULTI_PART, RequestEncodingType.URL_ENCODED):
expected_form_data = (media_type, data_signature_field, data_dto)
elif data_dto:
Expand Down Expand Up @@ -467,8 +467,11 @@ def _validate_raw_kwargs(
f"Make sure to use distinct keys for your dependencies, path parameters and aliased parameters."
)

used_reserved_kwargs = {*parameter_names, *path_parameters, *dependency_keys}.intersection(RESERVED_KWARGS)
if used_reserved_kwargs:
if used_reserved_kwargs := {
*parameter_names,
*path_parameters,
*dependency_keys,
}.intersection(RESERVED_KWARGS):
raise ImproperlyConfiguredException(
f"Reserved kwargs ({', '.join(RESERVED_KWARGS)}) cannot be used for dependencies and parameter arguments. "
f"The following kwargs have been used: {', '.join(used_reserved_kwargs)}"
Expand Down
6 changes: 4 additions & 2 deletions litestar/_kwargs/parameter_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_parameter_definition(
Returns:
A ParameterDefinition tuple.
"""
default_value = signature_field.default_value if not signature_field.is_empty else None
default_value = None if signature_field.is_empty else signature_field.default_value
kwargs_model = signature_field.kwarg_model if isinstance(signature_field.kwarg_model, ParameterKwarg) else None

field_alias = kwargs_model.query if kwargs_model and kwargs_model.query else field_name
Expand All @@ -60,7 +60,9 @@ def create_parameter_definition(
field_alias=field_alias,
default_value=default_value,
is_required=signature_field.is_required
and (default_value is None and not (signature_field.is_optional or signature_field.is_any)),
and default_value is None
and not signature_field.is_optional
and not signature_field.is_any,
is_sequence=signature_field.is_non_string_sequence,
)

Expand Down
2 changes: 1 addition & 1 deletion litestar/_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def parse_multipart_form(body: bytes, boundary: bytes, multipart_form_part_limit
line_index = line_end_index + 2
colon_index = form_line.index(":")
current_idx = colon_index + 2
form_header_field = form_line[0:colon_index].lower()
form_header_field = form_line[:colon_index].lower()
form_header_value, form_parameters = parse_content_header(form_line[current_idx:])

if form_header_field == "content-disposition":
Expand Down
2 changes: 1 addition & 1 deletion litestar/_openapi/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def get_layered_parameter(
layer_field = layered_parameters[field_name]

field = signature_field if signature_field.is_parameter_field else layer_field
default_value = signature_field.default_value if not signature_field.is_empty else layer_field.default_value
default_value = layer_field.default_value if signature_field.is_empty else signature_field.default_value
field_type = signature_field.field_type if signature_field is not Empty else layer_field.field_type # type: ignore

parameter_name = field_name
Expand Down
11 changes: 4 additions & 7 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,13 +695,10 @@ def for_typed_dict(self, field_type: TypedDictClass) -> Schema:
Returns:
A schema instance.
"""
annotations: dict[str, Any] = {}
for k, v in get_type_hints(field_type, include_extras=True).items():
if get_origin(v) in (Required, NotRequired):
annotations[k] = get_args(v)[0]
else:
annotations[k] = v

annotations: dict[str, Any] = {
k: get_args(v)[0] if get_origin(v) in (Required, NotRequired) else v
for k, v in get_type_hints(field_type, include_extras=True).items()
}
return Schema(
required=sorted(getattr(field_type, "__required_keys__", [])),
properties={k: self.for_field(SignatureField.create(v, k)) for k, v in annotations.items()},
Expand Down
6 changes: 2 additions & 4 deletions litestar/_openapi/typescript_converter/schema_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,10 @@ def normalize_typescript_namespace(value: str, allow_quoted: bool) -> str:
Returns:
A normalized value
"""
if not allow_quoted and not (value[0].isalpha() or value[0] in {"_", "$"}):
if not allow_quoted and not value[0].isalpha() and value[0] not in {"_", "$"}:
raise ValueError(f"invalid typescript namespace {value}")
if allow_quoted:
if allowed_key_re.fullmatch(value):
return value
return f'"{value}"'
return value if allowed_key_re.fullmatch(value) else f'"{value}"'
return invalid_namespace_re.sub("", value)


Expand Down
7 changes: 2 additions & 5 deletions litestar/_openapi/typescript_converter/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,12 @@

def _as_string(value: Any) -> str:
if isinstance(value, str):
return '"' + value + '"'
return f'"{value}"'

if isinstance(value, bool):
return "true" if value else "false"

if value is None:
return "null"

return str(value)
return "null" if value is None else str(value)


class TypeScriptElement(ABC):
Expand Down
2 changes: 1 addition & 1 deletion litestar/_openapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def default_operation_id_creator(
)

components_namespace = ""
for component in (c if not isinstance(c, PathParameterDefinition) else c.name for c in path_components):
for component in (c.name if isinstance(c, PathParameterDefinition) else c for c in path_components):
if component.title() not in components_namespace:
components_namespace += component.title()

Expand Down
Loading

0 comments on commit e5f2b64

Please sign in to comment.