-
Notifications
You must be signed in to change notification settings - Fork 35
/
summ_ai_api_client.py
341 lines (295 loc) · 13.9 KB
/
summ_ai_api_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
"""
This module contains the API client to interact with the SUMM.AI API
"""
from __future__ import annotations
import asyncio
import logging
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING
import aiohttp
from django.conf import settings
from django.contrib import messages
from django.utils.translation import gettext_lazy as _
from django.utils.translation import ngettext_lazy
from ..cms.utils.stringify_list import iter_to_string
from ..core.utils.machine_translation_api_client import MachineTranslationApiClient
from ..core.utils.machine_translation_provider import MachineTranslationProvider
from .utils import (
HTMLSegment,
PatientTaskQueue,
SummAiInvalidJSONError,
SummAiRateLimitingExceeded,
SummAiRuntimeError,
TextField,
TranslationHelper,
worker,
)
if TYPE_CHECKING:
from asyncio import AbstractEventLoop
from collections.abc import Callable
from typing import Any, Iterator
from aiohttp import ClientSession
from django.forms.models import ModelFormMetaclass
from django.http import HttpRequest
from ..cms.models.pages.page import Page
# Module-level logger, named after this module's import path
logger = logging.getLogger(name=__name__)
class SummAiApiClient(MachineTranslationApiClient):
    """
    SUMM.AI API client to get German pages in Easy German language.

    The text fields of the given content objects are sent to the SUMM.AI API
    concurrently (see :meth:`translate_text_fields`) and the results are then
    committed via the translation helpers (see :meth:`translate_queryset`).
    """

    def __init__(self, request: HttpRequest, form_class: ModelFormMetaclass) -> None:
        """
        Constructor initializes the class variables and checks that SUMM.AI may be used

        :param request: The current request (its ``region`` and ``user`` are used for the permission checks)
        :param form_class: The :class:`~integreat_cms.cms.forms.custom_content_model_form.CustomContentModelForm`
                           subclass of the current content type
        :raises RuntimeError: If machine translations are not permitted for this content type and user,
                              if SUMM.AI is disabled globally or if SUMM.AI is disabled in the current region
        """
        super().__init__(request, form_class)
        if not MachineTranslationProvider.is_permitted(
            request.region, request.user, form_class._meta.model
        ):
            raise RuntimeError(
                f'Machine translations are disabled for content type "{form_class._meta.model}" and {request.user!r}.'
            )
        if not settings.SUMM_AI_ENABLED:
            raise RuntimeError("SUMM.AI is disabled globally.")
        # NOTE(review): ``self.region`` is presumably set by the base class constructor
        if not self.region.summ_ai_enabled:
            raise RuntimeError(f"SUMM.AI is disabled in {self.region!r}.")

    async def translate_text_field(
        self, session: ClientSession, text_field: TextField
    ) -> TextField:
        """
        Translate a single text field with an asynchronous POST request to the SUMM.AI API.

        Uses :meth:`aiohttp.ClientSession.post` to perform the request. After the translation is finished,
        the processing is delegated to the specific textfield's
        :meth:`~integreat_cms.summ_ai_api.utils.TextField.translate`.
        Client errors, timeouts and :class:`~integreat_cms.summ_ai_api.utils.SummAiRuntimeError` raised
        while performing the request are not propagated; they are logged and stored in ``text_field.exception``.

        Note that :func:`~integreat_cms.summ_ai_api.utils.worker` currently not only counts
        :class:`~integreat_cms.summ_ai_api.utils.SummAiRateLimitingExceeded` but also
        :class:`~integreat_cms.summ_ai_api.utils.SummAiInvalidJSONError` as a rate limit hit and enqueues the task again.

        :param session: The session object which is used for the request
        :param text_field: The text field to be translated
        :return: The modified text field containing the translated text (or the caught exception)
        :raises SummAiRuntimeError: If text_field is none or text is empty
        :raises SummAiInvalidJSONError: Invalid JSON response returned by the API
        :raises SummAiRateLimitingExceeded: If query runs into rate limit (429 or 529 response)
        """
        logger.debug("Translating %r", text_field)
        # Use test region for development
        user = settings.TEST_REGION_SLUG if settings.DEBUG else self.region.slug
        # Set the language level to "plain" if the region prefers Plain German
        output_language_level = (
            "plain"
            if self.request.region.slug in settings.SUMM_AI_PLAIN_GERMAN_REGIONS
            else "easy"
        )
        if (
            text_field is None
            or (isinstance(text_field, TextField) and not text_field.text)
            or (isinstance(text_field, HTMLSegment) and text_field.segment is None)
        ):
            # This is normally filtered out before this function is called,
            # something must have gone wrong.
            # Raise an exception without immediately catching it!
            raise SummAiRuntimeError("Field to translate is None or empty")
        try:
            async with session.post(
                settings.SUMM_AI_API_URL,
                headers={"Authorization": f"Bearer {settings.SUMM_AI_API_KEY}"},
                json={
                    "input_text": text_field.text,
                    "user": user,
                    "separator": settings.SUMM_AI_SEPARATOR,
                    "is_test": settings.SUMM_AI_TEST_MODE,
                    "is_initial": settings.SUMM_AI_IS_INITIAL,
                    "output_language_level": output_language_level,
                },
            ) as response:
                # Wait for the response
                try:
                    response_data = await response.json()
                except (aiohttp.ContentTypeError, ValueError) as e:
                    # ContentTypeError: the response is not declared as JSON.
                    # ValueError (e.g. json.JSONDecodeError): the body is declared as
                    # JSON but cannot be decoded - previously this escaped unhandled.
                    logger.error(
                        "SUMM.AI API %s response failed to parse as JSON: %s: %s",
                        response.status,
                        type(e),
                        e,
                    )
                    raise SummAiInvalidJSONError(
                        f"API delivered invalid JSON: {response.status} - {await response.text()}"
                    ) from e
                if self.validate_response(response_data, response.status):
                    # Let the field handle the translated text
                    text_field.translate(response_data["translated_text"])
                # If text is not in response, validate_response()
                # will raise exceptions - so we don't need an else branch.
        except (aiohttp.ClientError, asyncio.TimeoutError, SummAiRuntimeError) as e:
            logger.error(
                "SUMM.AI translation of %r failed because of %s: %s",
                text_field,
                type(e),
                e,
            )
            text_field.exception = e
        return text_field

    async def translate_text_fields(
        self, loop: AbstractEventLoop, text_fields: Iterator[TextField]
    ) -> chain[TextField]:
        """
        Translate a list of text fields from German into Easy German.

        Create an async task
        :meth:`~integreat_cms.summ_ai_api.summ_ai_api_client.SummAiApiClient.translate_text_field`
        for each entry and distribute the tasks over ``settings.SUMM_AI_MAX_CONCURRENT_REQUESTS`` workers.

        :param loop: The asyncio event loop
        :param text_fields: The text fields to be translated
        :returns: An iterator over the completed text fields of all workers (flattened)
        """
        # Set a custom SUMM.AI timeout
        timeout = aiohttp.ClientTimeout(total=60 * settings.SUMM_AI_TIMEOUT)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            # Create tasks for each text field
            tasks = [
                # translate_text_field() gives us a coroutine that can be executed
                # asynchronously as a task. If we have to repeat the task
                # (e.g. if we run into rate limiting and have to resend the request),
                # we need a NEW coroutine object.
                # For that case, we need a representation of our function which can be
                # evaluated when needed, giving a new coroutine for the task each time.
                partial(self.translate_text_field, session, text_field)
                for text_field in text_fields
            ]

            # If the translation is aborted, set the exception field
            # to both signal that this wasn't translated and to display a reason why
            def abort_function(task: partial, reason: Any) -> None:
                # Retrieve field from arguments to translate_text_field()
                field = task.args[1]
                # Set the exception
                field.exception = f"Machine translation aborted: {reason}"

            # A "patient" task queue which only hands out sleep tasks after a task was reported as failed
            task_generator = PatientTaskQueue(tasks, abort_function=abort_function)
            # Wait for all tasks to finish and collect the results (one result list per worker)
            worker_results = await asyncio.gather(
                *[
                    worker(loop, task_generator, str(i))
                    for i in range(settings.SUMM_AI_MAX_CONCURRENT_REQUESTS)
                ]
            )
            # Put all results in one single iterable; chain(worker_results) would
            # only yield the per-worker lists themselves instead of flattening them
            all_results = chain.from_iterable(worker_results)
            return all_results

    def translate_queryset(self, queryset: list[Page], language_slug: str) -> None:
        """
        Translate a queryset of content objects from German into Easy German.
        To increase the speed of the translations, all operations are parallelized.

        :param queryset: The queryset which should be translated
        :param language_slug: The target language slug to translate into. Currently ignored:
                              the target language is always ``settings.SUMM_AI_EASY_GERMAN_LANGUAGE_SLUG``.
        """
        # Make sure both languages exist
        self.request.region.get_language_or_404(settings.SUMM_AI_GERMAN_LANGUAGE_SLUG)
        easy_german = self.request.region.get_language_or_404(
            settings.SUMM_AI_EASY_GERMAN_LANGUAGE_SLUG
        )
        # Initialize translation helpers for each object instance
        translation_helpers = [
            TranslationHelper(self.request, self.form_class, object_instance)
            for object_instance in queryset
        ]
        # Aggregate all strings that need to be translated.
        # get_text_fields() is deliberately called eagerly here, i.e. before the
        # event loop is running, in case it accesses the database.
        text_fields = chain(
            *[
                translation_helper.get_text_fields()
                for translation_helper in translation_helpers
            ]
        )
        # Initialize async event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # Translate queryset asynchronously in parallel
            loop.run_until_complete(self.translate_text_fields(loop, text_fields))
        finally:
            # Release the loop (like asyncio.run() does) instead of leaking
            # one unclosed event loop per call
            asyncio.set_event_loop(None)
            loop.close()
        # Commit changes to the database
        successes = []
        errors = []
        for translation_helper in translation_helpers:
            if TYPE_CHECKING:
                assert translation_helper.german_translation
            if translation_helper.commit(easy_german):
                successes.append(translation_helper.german_translation.title)
            else:
                errors.append(translation_helper.german_translation.title)
        if translation_helpers:
            meta = type(translation_helpers[0].object_instance)._meta
            model_name = meta.verbose_name.title()
            model_name_plural = meta.verbose_name_plural
        else:
            model_name = model_name_plural = ""
        if successes:
            messages.success(
                self.request,
                ngettext_lazy(
                    "{model_name} {object_names} has been successfully translated into Easy German.",
                    "The following {model_name_plural} have been successfully translated into Easy German: {object_names}",
                    len(successes),
                ).format(
                    model_name=model_name,
                    model_name_plural=model_name_plural,
                    object_names=iter_to_string(successes),
                ),
            )
        if errors:
            messages.error(
                self.request,
                ngettext_lazy(
                    "{model_name} {object_names} could not be automatically translated into Easy German.",
                    "The following {model_name_plural} could not be automatically translated into Easy German: {object_names}",
                    len(errors),
                ).format(
                    model_name=model_name,
                    model_name_plural=model_name_plural,
                    object_names=iter_to_string(errors),
                ),
            )

    @classmethod
    def validate_response(cls, response_data: dict, response_status: int) -> bool:
        """
        Checks if translated text is found in the SUMM.AI API response

        :param response_data: The decoded JSON response data from the SUMM.AI API
        :param response_status: The HTTP response status from the SUMM.AI API request
        :returns: True if the response contains the field ``translated_text``
        :raises SummAiRuntimeError: If the API reported an internal server error (500), the response
                                    is not a JSON object or it doesn't contain the field translated_text
        :raises SummAiRateLimitingExceeded: If the rate limit was exceeded (429 or 529)
        """
        cls.check_internal_server_error(response_status)
        cls.check_rate_limit_exceeded(response_status)
        # Valid JSON is not necessarily an object (e.g. ``null`` or a list); treat
        # that like any other unexpected result instead of failing with a TypeError
        if not isinstance(response_data, dict):
            raise SummAiRuntimeError(
                f"Unexpected API result: {response_status} - {response_data!r}"
            )
        if "translated_text" not in response_data:
            if "error" in response_data:
                raise SummAiRuntimeError(
                    f"API error: {response_status} - {response_data['error']}"
                )
            raise SummAiRuntimeError(
                f"Unexpected API result: {response_status} - {response_data!r}"
            )
        return True

    @staticmethod
    def check_internal_server_error(response_status: int) -> bool:
        """
        Checks if we got a HTTP 500 error

        :param response_status: The HTTP response status from the SUMM.AI API request
        :returns: False (if the response_status is not 500)
        :raises SummAiRuntimeError: If the response_status is 500
        """
        if response_status == 500:
            logger.error("SUMM.AI API has internal server error")
            raise SummAiRuntimeError("API has internal server error")
        return False

    @staticmethod
    def check_rate_limit_exceeded(response_status: int) -> bool:
        """
        Checks if the limit of requests was exceeded (triggered by response_status=429 or 529) and logs this occurrence

        :param response_status: The HTTP response status from the SUMM.AI API request
        :returns: False (if the response_status is neither 429 nor 529)
        :raises SummAiRateLimitingExceeded: If the response_status is 429 or 529
        """
        if response_status in (429, 529):
            logger.error(
                "SUMM.AI translation is waiting for %ss because the rate limit has been exceeded",
                settings.SUMM_AI_RATE_LIMIT_COOLDOWN,
            )
            raise SummAiRateLimitingExceeded
        return False