Skip to content

Commit

Permalink
Merge pull request #45 from a-luna:improve-rate-limit-logic-and-add-c…
Browse files Browse the repository at this point in the history
…lient-ip-tracking_20240504

✨ Improve rate limit logic and add client IP tracking
  • Loading branch information
a-luna authored May 4, 2024
2 parents f55c15f + 2def64b commit df59d1c
Show file tree
Hide file tree
Showing 41 changed files with 227 additions and 206 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.11
FROM python:3.12
SHELL ["/bin/bash", "-c"]

ARG ENV
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/dependencies/block_name_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import HTTPException, Path, Query, status

import app.db.models as db
from app.data.cache import cached_data
from app.core.cache import cached_data
from app.docs.dependencies.custom_parameters import BLOCK_NAME_DESCRIPTION, CHAR_SEARCH_BLOCK_NAME_DESCRIPTION
from app.schemas.enums.block_name import UnicodeBlockName

Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/dependencies/filter_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import HTTPException, Query, status

from app.api.api_v1.dependencies.filter_param_matcher import filter_param_matcher
from app.data.cache import cached_data
from app.core.cache import cached_data
from app.docs.dependencies.custom_parameters import (
CHAR_NAME_FILTER_DESCRIPTION,
CJK_DEFINITION_FILTER_DESCRIPTION,
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/dependencies/list_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import HTTPException, Query, status

from app.api.api_v1.dependencies.util import get_decimal_number_from_hex_codepoint
from app.data.cache import cached_data
from app.core.cache import cached_data
from app.docs.dependencies.custom_parameters import (
ENDING_BEFORE_BLOCK_ID_DESCRIPTION,
ENDING_BEFORE_CODEPOINT_DESCRIPTION,
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/dependencies/plane_abbrev_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from fastapi import HTTPException, Query, status

from app.data.cache import cached_data
from app.core.cache import cached_data
from app.docs.dependencies.custom_parameters import PLANE_NAME_DESCRIPTION


Expand Down
83 changes: 38 additions & 45 deletions app/api/api_v1/dependencies/util.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,71 @@
import re

from fastapi import HTTPException, status

from app.constants import (
ASCII_HEX,
CP_NEED_LEADING_ZEROS_REGEX,
CP_NO_PREFIX_REGEX_STRICT,
CP_OUT_OF_RANGE_REGEX,
CP_PREFIX_1_REGEX_STRICT,
CP_PREFIX_2_REGEX_STRICT,
MAX_CODEPOINT,
)
from app.core.encoding import get_codepoint_string
from app.core.result import Result
from app.data.constants import ASCII_HEX, MAX_CODEPOINT
from app.data.encoding import get_codepoint_string

CP_PREFIX_1_REGEX = re.compile(r"^U\+([A-Fa-f0-9]{4,6})$")
CP_PREFIX_2_REGEX = re.compile(r"^0x([A-Fa-f0-9]{2,6})$")
CP_NO_PREFIX_REGEX = re.compile(r"^([A-Fa-f0-9]{2,6})$")
CP_NEED_LEADING_ZEROS_REGEX = re.compile(r"^U\+([A-Fa-f0-9]{1,3})$")
CP_OUT_OF_RANGE_REGEX = re.compile(r"^(?:U\+)([A-Fa-f0-9]+)|(?:0x)?([A-Fa-f0-9]{7,})$")
from app.core.util import s


def get_decimal_number_from_hex_codepoint(codepoint: str, starting_after: bool = True) -> int:
result = get_codepoint_hex_from_string(codepoint)
if result.failure:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result.error)
cp_hex = result.value or "0"
codepoint_dec = int(cp_hex, 16)
result = check_codepoint_is_in_unicode_range(codepoint_dec, starting_after)
cp_dec = int(cp_hex, 16)
result = check_codepoint_is_in_unicode_range(cp_dec, starting_after)
if result.success:
return codepoint_dec
return cp_dec
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result.error)


def get_codepoint_hex_from_string(s: str) -> Result[str]:
match = CP_PREFIX_1_REGEX.match(s)
if match:
def get_codepoint_hex_from_string(cp: str) -> Result[str]:
if match := CP_PREFIX_1_REGEX_STRICT.match(cp):
return Result.Ok(match[1])
match = CP_PREFIX_2_REGEX.match(s)
if match:
if match := CP_PREFIX_2_REGEX_STRICT.match(cp):
return Result.Ok(match[1])
match = CP_NO_PREFIX_REGEX.match(s)
if match:
if match := CP_NO_PREFIX_REGEX_STRICT.match(cp):
return Result.Ok(match[1])
return Result.Fail(get_error_message_for_invalid_codepoint_value(s))
return Result.Fail(get_error_message_for_invalid_codepoint_value(cp))


def check_codepoint_is_in_unicode_range(codepoint: int, starting_after: bool) -> Result[int]:
    """Check ``codepoint`` against the inclusive bounds selected by ``starting_after``.

    Args:
        codepoint: Decimal codepoint value to validate.
        starting_after: When true the accepted range is ``0`` through
            ``MAX_CODEPOINT``; otherwise it is ``1`` through ``MAX_CODEPOINT + 1``.
            NOTE(review): the shifted range presumably serves an "ending before"
            cursor that may sit one past the last codepoint -- confirm with callers.

    Returns:
        ``Result.Ok(codepoint)`` when the value is inside the selected bounds,
        otherwise ``Result.Fail`` carrying an explanatory message.
    """
    if starting_after:
        lowest, highest = 0, MAX_CODEPOINT
    else:
        lowest, highest = 1, MAX_CODEPOINT + 1
    if codepoint not in range(lowest, highest + 1):
        return Result.Fail(
            f"{get_codepoint_string(codepoint)} is not within the Unicode codespace (U+0000 to U+10FFFF)."
        )
    return Result.Ok(codepoint)


def get_error_message_for_invalid_codepoint_value(s: str) -> str:
sanitized_codepoint = sanitize_codepoint_value(s)
if match := CP_NEED_LEADING_ZEROS_REGEX.search(s):
def get_error_message_for_invalid_codepoint_value(cp: str) -> str:
sanitized_codepoint = sanitize_codepoint_value(cp)
if match := CP_NEED_LEADING_ZEROS_REGEX.search(cp):
return (
f"The value provided (U+{sanitized_codepoint.upper()}) is invalid because Unicode codepoint values "
"prefixed with 'U+' must contain at least 4 hexadecimal digits. The correct way to request "
f"the character assigned to codepoint 0x{match[1].upper()} is with the value "
f"'{get_codepoint_string(int(match[1], 16))}', which adds the necessary leading zeros."
)
invalid_chars = get_invalid_hex_characters(sanitized_codepoint)
if invalid_chars:
if invalid_chars := get_invalid_hex_characters(sanitized_codepoint):
return (
f"The value provided ({s}) contains {len(invalid_chars)} invalid hexadecimal "
f"character{'s' if len(invalid_chars) > 1 else ''}: [{', '.join(invalid_chars)}]. "
"The codepoint value must be expressed as a hexadecimal value within range 0000...10FFFF, "
"optionally prefixed by 'U+'' or '0x'."
f"The value provided ({cp}) contains {len(invalid_chars)} invalid hexadecimal "
f"character{s(invalid_chars)}: [{', '.join(invalid_chars)}]. The codepoint value must be expressed "
"as a hexadecimal value within range 0000...10FFFF, optionally prefixed by 'U+'' or '0x'."
)
return (
(
if match := CP_OUT_OF_RANGE_REGEX.match(cp):
return (
f"U+{match[1] or match[2]} is not within the range of valid codepoints for Unicode characters "
"(U+0000 to U+10FFFF)."
)
if (match := CP_OUT_OF_RANGE_REGEX.match(s))
else "Error! Value provided is not a valid hexadecimal number."
)
return "Error! Value provided is not a valid hexadecimal number."


def sanitize_codepoint_value(codepoint: str) -> str:
Expand All @@ -71,13 +74,3 @@ def sanitize_codepoint_value(codepoint: str) -> str:

def get_invalid_hex_characters(s: str) -> list[str]:
    """Return the distinct characters of ``s`` that are absent from ``ASCII_HEX``, sorted."""
    non_hex_chars = set(s).difference(ASCII_HEX)
    return sorted(non_hex_chars)


def check_codepoint_is_in_unicode_range(codepoint: int, starting_after: bool) -> Result[int]:
    """Check ``codepoint`` against the inclusive bounds selected by ``starting_after``.

    With ``starting_after`` true the accepted range is ``0`` through
    ``MAX_CODEPOINT``; otherwise it is ``1`` through ``MAX_CODEPOINT + 1``.

    Returns:
        ``Result.Ok(codepoint)`` when the value is inside the selected bounds,
        otherwise ``Result.Fail`` carrying an explanatory message.
    """
    if starting_after:
        lowest, highest = 0, MAX_CODEPOINT
    else:
        lowest, highest = 1, MAX_CODEPOINT + 1
    if codepoint not in range(lowest, highest + 1):
        formatted = get_codepoint_string(codepoint)
        return Result.Fail(
            f"{formatted} is not within the range of valid codepoints for Unicode characters (U+0000 to U+10FFFF)."
        )
    return Result.Ok(codepoint)
2 changes: 1 addition & 1 deletion app/api/api_v1/endpoints/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from app.api.api_v1.pagination import paginate_search_results
from app.config.api_settings import get_settings
from app.data.cache import cached_data
from app.core.cache import cached_data

router = APIRouter()

Expand Down
4 changes: 2 additions & 2 deletions app/api/api_v1/endpoints/characters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from app.api.api_v1.dependencies.filter_param_matcher import filter_param_matcher
from app.api.api_v1.endpoints.util import get_character_details
from app.api.api_v1.pagination import paginate_search_results
from app.data.cache import cached_data
from app.data.encoding import get_codepoint_string
from app.core.cache import cached_data
from app.core.encoding import get_codepoint_string
from app.db.session import DBSession, get_session
from app.docs.dependencies.custom_parameters import (
UNICODE_CHAR_STRING_DESCRIPTION,
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/endpoints/planes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import app.db.models as db
from app.config.api_settings import get_settings
from app.data.cache import cached_data
from app.core.cache import cached_data

router = APIRouter()

Expand Down
7 changes: 4 additions & 3 deletions app/config/api_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import app.db.models as db
from app.config.dotenv_file import read_dotenv_file
from app.data.constants import UNICODE_PLANES_DEFAULT, UNICODE_VERSION_RELEASE_DATES
from app.constants import UNICODE_PLANES_DEFAULT, UNICODE_VERSION_RELEASE_DATES
from app.core.util import s


class ApiSettingsDict(TypedDict):
Expand Down Expand Up @@ -152,9 +153,9 @@ def api_settings_report(self) -> str:

@property
def rate_limit_settings_report(self) -> str:
rate = f"{self.RATE_LIMIT_PER_PERIOD} request{'s' if self.RATE_LIMIT_PER_PERIOD > 1 else ''}"
rate = f"{self.RATE_LIMIT_PER_PERIOD} request{s(self.RATE_LIMIT_PER_PERIOD)}"
interval = self.RATE_LIMIT_PERIOD_SECONDS.total_seconds()
period = f"{interval} seconds" if interval > 1 else "second"
period = f"{interval}second{s(interval)}"
rate_limit_settings = f"Rate Limit Settings: {rate} per {period}"
burst_enabled = self.RATE_LIMIT_BURST > 1
if burst_enabled: # pragma: no cover
Expand Down
10 changes: 9 additions & 1 deletion app/data/constants.py → app/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
MAX_CODEPOINT = 1114111
ALL_UNICODE_CODEPOINTS = range(MAX_CODEPOINT + 1)
ASCII_HEX = "0123456789ABCDEFabcdef"
CODEPOINT_WITH_PREFIX_REGEX = re.compile(r"(?:U\+([A-Fa-f0-9]{4,6}))")

DATE_MONTH_NAME = "%b %d, %Y"

CP_PREFIX_1_REGEX = re.compile(r"(?:U\+([A-Fa-f0-9]{4,6}))")
CP_PREFIX_1_REGEX_STRICT = re.compile(r"^U\+([A-Fa-f0-9]{4,6})$")
CP_PREFIX_2_REGEX_STRICT = re.compile(r"^0x([A-Fa-f0-9]{2,6})$")
CP_NO_PREFIX_REGEX_STRICT = re.compile(r"^([A-Fa-f0-9]{2,6})$")
CP_NEED_LEADING_ZEROS_REGEX = re.compile(r"^U\+([A-Fa-f0-9]{1,3})$")
CP_OUT_OF_RANGE_REGEX = re.compile(r"^(?:U\+)([A-Fa-f0-9]+)|(?:0x)?([A-Fa-f0-9]{7,})$")

CharacterFlag = namedtuple("CharacterFlag", ["name", "alias", "db_column"])

Expand Down
25 changes: 5 additions & 20 deletions app/data/cache.py → app/core/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import app.db.models as db
from app.config.api_settings import get_settings
from app.data.constants import (
from app.constants import (
ALL_CONTROL_CHARACTERS,
ALL_UNICODE_CODEPOINTS,
ASCII_HEX,
Expand Down Expand Up @@ -163,16 +163,6 @@ def all_surrogate_codepoints(self) -> set[int]:
def all_private_use_codepoints(self) -> set[int]:
    """Return every codepoint contained in the blocks listed in ``self.private_use_block_ids``."""
    block_ids = self.private_use_block_ids
    return self.get_all_codepoints_in_block_id_list(block_ids)

@property
def all_assigned_codepoints(self) -> set[int]:
    """Return the union of the non-Unihan, CJK, Tangut, surrogate and private-use codepoints."""
    codepoint_groups = (
        self.all_non_unihan_codepoints,
        self.all_cjk_codepoints,
        self.all_tangut_codepoints,
        self.all_surrogate_codepoints,
        self.all_private_use_codepoints,
    )
    # A single union call de-duplicates across groups, same as set(list + list + ...).
    return set().union(*codepoint_groups)

@property
def official_number_of_unicode_characters(self) -> int:
# The "official" number of characters listed for each version of Unicode is the total number
Expand Down Expand Up @@ -242,9 +232,6 @@ def get_unicode_plane_containing_block_id(self, block_id: int) -> db.UnicodePlan
def codepoint_is_in_unicode_space(self, codepoint: int) -> bool:
    """Report whether ``codepoint`` is a member of ``self.all_codepoints_in_unicode_space``."""
    codespace = self.all_codepoints_in_unicode_space
    return codepoint in codespace

def codepoint_is_assigned(self, codepoint: int) -> bool:
    """Report whether ``codepoint`` is a member of ``self.all_assigned_codepoints``."""
    assigned = self.all_assigned_codepoints
    return codepoint in assigned

def codepoint_is_noncharacter(self, codepoint: int) -> bool:
    """Report whether ``codepoint`` is a member of ``self.all_noncharacter_codepoints``."""
    noncharacters = self.all_noncharacter_codepoints
    return codepoint in noncharacters

Expand Down Expand Up @@ -335,13 +322,11 @@ def get_mapped_codepoint_from_hex(self, codepoint_hex: str) -> str: # pragma: n
return self.get_mapped_codepoint_from_int(int(codepoint_hex, 16))

def get_mapped_codepoint_from_int(self, codepoint_dec: int) -> str: # pragma: no cover
if codepoint_dec not in ALL_UNICODE_CODEPOINTS:
if not codepoint_dec:
return ""
if not self.codepoint_is_in_unicode_space(codepoint_dec):
return f"Invalid Codepoint ({codepoint_dec} is not within the Unicode codespace)"
return (
f"{chr(codepoint_dec)} (U+{codepoint_dec:04X} {cached_data.get_character_name(codepoint_dec)})"
if codepoint_dec
else ""
)
return f"{chr(codepoint_dec)} (U+{codepoint_dec:04X} {cached_data.get_character_name(codepoint_dec)})"

def get_all_codepoints_in_block_id_list(self, block_id_list: list[int]) -> set[int]:
blocks = [self.get_unicode_block_by_id(block_id) for block_id in block_id_list]
Expand Down
File renamed without changes.
33 changes: 17 additions & 16 deletions app/core/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
format_timedelta_str,
get_duration_between_timestamps,
get_time_until_timestamp,
s,
)

# Matches request paths that belong to a rate-limited /v1 route group.
# The route names are wrapped in a non-capturing group so that the "^/v1/" prefix
# applies to every alternative; ungrouped, "|" has the lowest precedence and the
# prefix would bind to "blocks" only, letting any path that merely contains
# "characters", "codepoints" or "planes" match.
RATE_LIMIT_ROUTE_REGEX = re.compile(r"^\/v1\/(?:blocks|characters|codepoints|planes)")
Expand Down Expand Up @@ -81,9 +82,9 @@ def is_exceeded(self, request: Request) -> Result[None]:
Adapted for Python from this article:
https://vikas-kumar.medium.com/rate-limiting-techniques-245c3a5e9cad
"""
if not self.apply_rate_limit_to_request(request):
client_ip = get_client_ip_address(request)
if not self.apply_rate_limit_to_request(request, client_ip):
return Result.Ok()
client_ip = request.client.host if request.client else "localhost"
arrived_at = self.redis.time()
self.redis.setnx(client_ip, "0")
try:
Expand All @@ -100,10 +101,10 @@ def is_exceeded(self, request: Request) -> Result[None]:
except LockError: # pragma: no cover
return self.lock_error(client_ip)

def apply_rate_limit_to_request(self, request: Request):
def apply_rate_limit_to_request(self, request: Request, client_ip: str):
if self.settings.is_test:
return enable_rate_limit_feature_for_test(request)
return request_origin_is_external(request) and requested_route_is_rate_limited(request) # pragma: no cover
return rate_limit_applies_to_route(request) and client_ip_is_external(request, client_ip) # pragma: no cover

def get_allowed_at(self, tat: float) -> float:
return (dtaware_fromtimestamp(tat) - self.delay_tolerance_ms).timestamp()
Expand Down Expand Up @@ -131,6 +132,12 @@ def lock_error(self, client) -> Result[None]: # pragma: no cover
return Result.Fail(error)


def get_client_ip_address(request: Request) -> str:
    """Return the address used to identify the client that sent ``request``.

    The ``X-Forwarded-For`` header is preferred when present. That header may hold
    a comma-separated chain (``client, proxy1, proxy2``), so only the first entry
    is used and surrounding whitespace is removed; this keeps the returned value
    stable for one client regardless of how many proxies appended themselves.
    When the header is missing or blank, the connection's peer host is used, and
    ``"localhost"`` is returned if the request carries no client information.

    Args:
        request: Incoming request; only ``headers`` and ``client`` are read.

    Returns:
        The client address as a non-empty string.
    """
    # NOTE(review): X-Forwarded-For is supplied by the sender unless a reverse proxy
    # under our control overwrites it -- confirm the deployment guarantees that.
    forwarded_for = request.headers.get("x-forwarded-for", "")
    first_hop = forwarded_for.split(",", 1)[0].strip()
    if first_hop:
        return first_hop
    return request.client.host if request.client else "localhost"


def enable_rate_limit_feature_for_test(request: Request) -> bool:
if "x-verify-rate-limiting" in request.headers:
return request.headers["x-verify-rate-limiting"] == "true"
Expand All @@ -141,26 +148,20 @@ def enable_rate_limit_feature_for_test(request: Request) -> bool:
return False # pragma: no cover


def request_origin_is_external(request: Request) -> bool: # pragma: no cover
if request.client.host in ["localhost", "127.0.0.1", "testserver"]:
def rate_limit_applies_to_route(request: Request) -> bool:  # pragma: no cover
    """Report whether the request's URL path is matched by ``RATE_LIMIT_ROUTE_REGEX``."""
    route_match = RATE_LIMIT_ROUTE_REGEX.search(request.url.path)
    return route_match is not None


def client_ip_is_external(request: Request, client_ip: str) -> bool:  # pragma: no cover
    """Decide whether a request should be treated as coming from outside the service.

    Args:
        request: Incoming request; only its ``sec-fetch-site`` header is consulted.
        client_ip: Address already resolved for the client.

    Returns:
        ``False`` for the known local host names and for addresses starting with
        ``172.17.0.`` (presumably the container bridge network -- confirm).
        Otherwise, when a ``sec-fetch-site`` header is present, ``True`` unless its
        value is ``"same-site"``; with no such header the result is ``True``.
    """
    local_hosts = ("localhost", "127.0.0.1", "testserver")
    is_local = client_ip in local_hosts or client_ip.startswith("172.17.0.")
    if is_local:
        return False
    headers = request.headers
    if "sec-fetch-site" not in headers:
        return True
    return headers["sec-fetch-site"] != "same-site"


def requested_route_is_rate_limited(request: Request) -> bool: # pragma: no cover
return bool(RATE_LIMIT_ROUTE_REGEX.search(request.url.path))


def get_time_portion(ts: float) -> str:
    """Format the time-of-day of ``ts`` as ``HH:MM:SS.microseconds AM/PM`` (12-hour clock).

    The timestamp is first converted with ``dtaware_fromtimestamp``; the date part
    of the result is discarded.
    """
    moment = dtaware_fromtimestamp(ts)
    return moment.time().strftime("%I:%M:%S.%f %p")


def s(x: list | int | float) -> str:
if isinstance(x, list):
return "s" if len(x) > 1 else ""
return "s" if x > 1 else ""


rate_limit = RateLimit(redis)
Loading

0 comments on commit df59d1c

Please sign in to comment.