diff --git a/notebooks/api/0.8/13-forgot-user-password.ipynb b/notebooks/api/0.8/13-forgot-user-password.ipynb new file mode 100644 index 00000000000..8ad3cdf0918 --- /dev/null +++ b/notebooks/api/0.8/13-forgot-user-password.ipynb @@ -0,0 +1,181 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Forgot User Password" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## Initialize the server" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# stdlib\n", + "\n", + "# syft absolute\n", + "import syft as sy\n", + "from syft import SyftError\n", + "from syft import SyftSuccess\n", + "\n", + "server = sy.orchestra.launch(\n", + " name=\"test-datasite-1\",\n", + " dev_mode=True,\n", + " create_producer=True,\n", + " n_consumers=3,\n", + " reset=True,\n", + " port=8081,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Register a new user" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "datasite_client = server.login(email=\"info@openmined.org\", password=\"changethis\")\n", + "res = datasite_client.register(\n", + " email=\"new_syft_user@openmined.org\",\n", + " password=\"verysecurepassword\",\n", + " password_verify=\"verysecurepassword\",\n", + " name=\"New User\",\n", + ")\n", + "\n", + "if not isinstance(res, SyftSuccess):\n", + " raise Exception(f\"Res isn't SyftSuccess, its {res}\")" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "### Ask for a password reset - Notifier disabled Workflow" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "### Call for users.forgot_password" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "guest_client = server.login_as_guest()\n", + "res = guest_client.users.forgot_password(email=\"new_syft_user@openmined.org\")\n", + "\n", + "if not isinstance(res, SyftSuccess):\n", + " raise Exception(f\"Res isn't SyftSuccess, its {res}\")" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "### Admin generates a temp token" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "temp_token = datasite_client.users.request_password_reset(\n", + " datasite_client.notifications[-1].linked_obj.resolve.id\n", + ")\n", + "\n", + "if not isinstance(temp_token, str):\n", + " raise Exception(f\"temp_token isn't a string, its {temp_token}\")" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "### User use this token to reset password" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "res = guest_client.users.reset_password(token=temp_token, new_password=\"Password123\")\n", + "\n", + "if not isinstance(res, SyftSuccess):\n", + " raise Exception(f\"Res isn't SyftSuccess, its {res}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "new_user_session = server.login(\n", + " email=\"new_syft_user@openmined.org\", password=\"Password123\"\n", + ")\n", + "\n", + "if isinstance(new_user_session, SyftError):\n", + " raise Exception(f\"Res isn't SyftSuccess, its {new_user_session}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 0eb298b0648..9cf7d42efd0 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -4,5 +4,51 @@ }, "2": { "release_name": "0.8.8.json" + }, + "dev": { + "object_versions": { + "User": { + "2": { + "version": 2, + "hash": "af6fb5b2e1606e97838f4a60f0536ad95db606d455e94acbd1977df866608a2c", + "action": "add" + } + }, + "UserNotificationActivity": { + "1": { + "version": 1, + "hash": "422fd01c6d9af38688a9982abd34e80794a1f6ddd444cca225d77f49189847a9", + "action": "add" + } + }, + "NotifierSettings": { + "2": { + "version": 2, + "hash": "be8b52597fc628d1b7cd22b776ee81416e1adbb04a45188778eb0e32ed1416b4", + "action": "add" + } + }, + "PwdTokenResetConfig": { + "1": { + "version": 1, + "hash": "0415a272428f22add4896c64aa9f29c8c1d35619e2433da6564eb5f1faff39ac", + "action": "add" + } + }, + "ServerSettingsUpdate": { + "3": { + "version": 3, + "hash": "335c7946f2e52d09c7b26f511120cd340717c74c5cca9107e84f839da993c55c", + "action": "add" + } + }, + "ServerSettings": { + "3": { + "version": 3, + "hash": "997667e1cba22d151857aacc2caba6b1ca73c1648adbd03461dc74a0c0c372b3", + "action": "add" + } + } + } } } diff --git a/packages/syft/src/syft/service/notification/email_templates.py b/packages/syft/src/syft/service/notification/email_templates.py index f8baceee38a..fec2810b02a 100644 --- a/packages/syft/src/syft/service/notification/email_templates.py +++ b/packages/syft/src/syft/service/notification/email_templates.py @@ -1,4 +1,5 @@ # stdlib +from datetime import datetime from typing import TYPE_CHECKING from typing import cast @@ -22,6 +23,91 @@ def email_body(notification: "Notification", context: AuthedServiceContext) -> s return "" +@serializable(canonical_name="PasswordResetTemplate", version=1) +class PasswordResetTemplate(EmailTemplate): + @staticmethod + def email_title(notification: "Notification", context: AuthedServiceContext) -> str: + return "Password Reset Requested" + + @staticmethod + def email_body(notification: "Notification", context: AuthedServiceContext) -> str: + user_service = context.server.get_service("userservice") + + user = user_service.get_by_verify_key(notification.to_user_verify_key) + if not user: + raise Exception("User not found!") + + user.reset_token = user_service.generate_new_password_reset_token( + context.server.settings.pwd_token_config + ) + user.reset_token_date = datetime.now() + + result = user_service.stash.update( + credentials=context.credentials, user=user, has_permission=True + ) + if result.is_err(): + raise Exception("Couldn't update the user password") + + head = """ + + """ + body = f""" +
+

Password Reset

+

We received a request to reset your password. Your new temporary token is:

+

{user.reset_token}

+

Use + + syft_client.reset_password(token='{user.reset_token}', new_password=*****) + . + to reset your password.

+

If you didn't request a password reset, please ignore this email.

+
+ """ + return f"""{head} {body}""" + + @serializable(canonical_name="OnboardEmailTemplate", version=1) class OnBoardEmailTemplate(EmailTemplate): @staticmethod diff --git a/packages/syft/src/syft/service/notifier/notifier.py b/packages/syft/src/syft/service/notifier/notifier.py index 6cc24a3d720..d9cdef82c2c 100644 --- a/packages/syft/src/syft/service/notifier/notifier.py +++ b/packages/syft/src/syft/service/notifier/notifier.py @@ -6,8 +6,9 @@ # 2) .....settings().x_enabled # 2) .....user_settings().x - # stdlib +from collections.abc import Callable +from datetime import datetime from typing import TypeVar # third party @@ -18,8 +19,12 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject +from ...types.transforms import drop +from ...types.transforms import make_set_default from ..context import AuthedServiceContext from ..notification.notifications import Notification from ..response import SyftError @@ -38,6 +43,14 @@ def send( TBaseNotifier = TypeVar("TBaseNotifier", bound=BaseNotifier) +@serializable() +class UserNotificationActivity(SyftObject): + __canonical_name__ = "UserNotificationActivity" + __version__ = SYFT_OBJECT_VERSION_1 + count: int = 1 + date: datetime = datetime.now() + + @serializable(canonical_name="EmailNotifier", version=1) class EmailNotifier(BaseNotifier): smtp_client: SMTPClient @@ -131,7 +144,7 @@ class NotificationPreferences(SyftObject): @serializable() -class NotifierSettings(SyftObject): +class NotifierSettingsV1(SyftObject): __canonical_name__ = "NotifierSettings" __version__ = SYFT_OBJECT_VERSION_1 __repr_attrs__ = [ @@ -139,6 +152,34 @@ class NotifierSettings(SyftObject): "email_enabled", ] active: bool = False + + notifiers: dict[NOTIFIERS, type[TBaseNotifier]] = { + NOTIFIERS.EMAIL: EmailNotifier, + } + + notifiers_status: dict[NOTIFIERS, bool] = { + NOTIFIERS.EMAIL: True, + NOTIFIERS.SMS: False, + NOTIFIERS.SLACK: False, + NOTIFIERS.APP: False, + } + + email_sender: str | None = "" + email_server: str | None = "" + email_port: int | None = 587 + email_username: str | None = "" + email_password: str | None = "" + + +@serializable() +class NotifierSettings(SyftObject): + __canonical_name__ = "NotifierSettings" + __version__ = SYFT_OBJECT_VERSION_2 + __repr_attrs__ = [ + "active", + "email_enabled", + ] + active: bool = False # Flag to identify which notification is enabled # For now, consider only the email notification # In future, Admin, must be able to have a better @@ -161,6 +202,9 @@ class NotifierSettings(SyftObject): email_username: str | None = "" email_password: str | None = "" + email_activity: dict[str, dict[SyftVerifyKey, UserNotificationActivity]] = {} + email_rate_limit: dict[str, int] = {} + @property def email_enabled(self) -> bool: return self.notifiers_status[NOTIFIERS.EMAIL] @@ -237,3 +281,17 @@ def select_notifiers(self, notification: Notification) -> list[BaseNotifier]: notifier_objs.append(self.notifiers[notifier_type]()) # type: ignore[misc] return notifier_objs + + +@migrate(NotifierSettingsV1, NotifierSettings) +def migrate_server_settings_v1_to_current() -> list[Callable]: + return [ + make_set_default("email_activity", {}), + make_set_default("email_rate_limit", {}), + ] + + +@migrate(NotifierSettings, NotifierSettingsV1) +def migrate_server_settings_v2_to_v1() -> list[Callable]: + # Use drop function on "notifications_enabled" attrubute + return [drop(["email_activity"]), drop(["email_rate_limit"])] diff --git a/packages/syft/src/syft/service/notifier/notifier_enums.py b/packages/syft/src/syft/service/notifier/notifier_enums.py index 023843f7d6c..f8c2d887ff4 100644 --- a/packages/syft/src/syft/service/notifier/notifier_enums.py +++ b/packages/syft/src/syft/service/notifier/notifier_enums.py @@ -6,6 +6,14 @@ from ...serde.serializable import serializable +@serializable(canonical_name="EMAIL_TYPES", version=1) +class EMAIL_TYPES(Enum): + PASSWORD_RESET_EMAIL = "PasswordResetTemplate" # nosec + ONBOARD_EMAIL = "OnBoardEmailTemplate" + REQUEST_EMAIL = "RequestEmailTemplate" + REQUEST_UPDATE_EMAIL = "RequestUpdateEmailTemplate" + + @serializable(canonical_name="NOTIFIERS", version=1) class NOTIFIERS(Enum): EMAIL = auto() diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index 391bf16abd4..c8c09ba3d50 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -1,4 +1,5 @@ # stdlib +from datetime import datetime import logging import traceback @@ -13,12 +14,15 @@ from ...serde.serializable import serializable from ...store.document_store import DocumentStore from ..context import AuthedServiceContext +from ..notification.email_templates import PasswordResetTemplate from ..notification.notifications import Notification from ..response import SyftError from ..response import SyftSuccess from ..service import AbstractService from .notifier import NotificationPreferences from .notifier import NotifierSettings +from .notifier import UserNotificationActivity +from .notifier_enums import EMAIL_TYPES from .notifier_enums import NOTIFIERS from .notifier_stash import NotifierStash @@ -188,6 +192,10 @@ def turn_on( message="You must provide a sender email address to enable notifications." ) + # If email_rate_limit isn't defined yet. + if not notifier.email_rate_limit: + notifier.email_rate_limit = {PasswordResetTemplate.__name__: 3} + if email_sender: try: EmailStr._validate(email_sender) @@ -320,6 +328,8 @@ def init_notifier( notifier.email_sender = email_sender notifier.email_server = smtp_host notifier.email_port = smtp_port + # Default daily email rate limit per user + notifier.email_rate_limit = {PasswordResetTemplate.__name__: 3} notifier.active = True notifier_stash.set(server.signing_key.verify_key, notifier) @@ -328,6 +338,22 @@ def init_notifier( except Exception: raise Exception(f"Error initializing notifier. \n {traceback.format_exc()}") + def set_email_rate_limit( + self, context: AuthedServiceContext, email_type: EMAIL_TYPES, daily_limit: int + ) -> SyftSuccess | SyftError: + notifier = self.stash.get(context.credentials) + if notifier.is_err(): + return SyftError(message="Couldn't set the email rate limit.") + + notifier = notifier.ok() + + notifier.email_rate_limit[email_type.value] = daily_limit + result = self.stash.update(credentials=context.credentials, settings=notifier) + if result.is_err(): + return SyftError(message="Couldn't update the notifier.") + + return SyftSuccess(message="Email rate limit updated!") + # This is not a public API. # This method is used by other services to dispatch notifications internally def dispatch_notification( @@ -343,7 +369,51 @@ def dispatch_notification( notifier = notifier.ok() # If notifier is active - if notifier.active: + if notifier.active and notification.email_template is not None: + logging.debug("Checking user email activity") + if notifier.email_activity.get(notification.email_template.__name__, None): + user_activity = notifier.email_activity[ + notification.email_template.__name__ + ].get(notification.to_user_verify_key, None) + # If there's no user activity + if user_activity is None: + notifier.email_activity[notification.email_template.__name__][ + notification.to_user_verify_key, None + ] = UserNotificationActivity(count=1, date=datetime.now()) + else: # If there's a previous user activity + current_state: UserNotificationActivity = notifier.email_activity[ + notification.email_template.__name__ + ][notification.to_user_verify_key] + date_refresh = abs(datetime.now() - current_state.date).days > 1 + + limit = notifier.email_rate_limit.get( + notification.email_template.__name__, 0 + ) + still_in_limit = current_state.count < limit + # Time interval reseted. + if date_refresh: + current_state.count = 1 + current_state.date = datetime.now() + # Time interval didn't reset yet. + elif still_in_limit or not limit: + current_state.count += 1 + current_state.date = datetime.now() + else: + return SyftError( + message="Couldn't send the email. You have surpassed the" + + " email threshold limit. Please try again later." + ) + else: + notifier.email_activity[notification.email_template.__name__] = { + notification.to_user_verify_key: UserNotificationActivity( + count=1, date=datetime.now() + ) + } + + result = self.stash.update(credentials=admin_key, settings=notifier) + if result.is_err(): + return SyftError(message="Couldn't update the notifier.") + resp = notifier.send_notifications( context=context, notification=notification ) diff --git a/packages/syft/src/syft/service/settings/settings.py b/packages/syft/src/syft/service/settings/settings.py index bf3369d7cb9..67720658c80 100644 --- a/packages/syft/src/syft/service/settings/settings.py +++ b/packages/syft/src/syft/service/settings/settings.py @@ -3,6 +3,11 @@ import logging from typing import Any +# third party +from pydantic import field_validator +from pydantic import model_validator +from typing_extensions import Self + # relative from ...abstract_server import ServerSideType from ...abstract_server import ServerType @@ -13,6 +18,7 @@ from ...types.syft_object import PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.transforms import drop from ...types.transforms import make_set_default @@ -26,6 +32,32 @@ logger = logging.getLogger(__name__) +@serializable() +class PwdTokenResetConfig(SyftObject): + __canonical_name__ = "PwdTokenResetConfig" + __version__ = SYFT_OBJECT_VERSION_1 + ascii: bool = True + numbers: bool = True + token_len: int = 12 + token_exp_min: int = 30 + + @model_validator(mode="after") + def validate_char_types(self) -> Self: + if not self.ascii and not self.numbers: + raise ValueError( + "Invalid config, at least one of the ascii/number options must be true." + ) + + return self + + @field_validator("token_len") + @classmethod + def check_token_len(cls, value: int) -> int: + if value < 4: + raise ValueError("Token length must be greater than 4.") + return value + + @serializable() class ServerSettingsUpdateV1(PartialSyftObject): __canonical_name__ = "ServerSettingsUpdate" @@ -43,7 +75,7 @@ class ServerSettingsUpdateV1(PartialSyftObject): @serializable() -class ServerSettingsUpdate(PartialSyftObject): +class ServerSettingsUpdateV2(PartialSyftObject): __canonical_name__ = "ServerSettingsUpdate" __version__ = SYFT_OBJECT_VERSION_2 id: UID @@ -59,6 +91,24 @@ class ServerSettingsUpdate(PartialSyftObject): notifications_enabled: bool +@serializable() +class ServerSettingsUpdate(PartialSyftObject): + __canonical_name__ = "ServerSettingsUpdate" + __version__ = SYFT_OBJECT_VERSION_3 + id: UID + name: str + organization: str + description: str + on_board: bool + signup_enabled: bool + admin_email: str + association_request_auto_approval: bool + welcome_markdown: HTMLObject | MarkdownDescription + eager_execution_enabled: bool + notifications_enabled: bool + pwd_token_config: PwdTokenResetConfig + + @serializable() class ServerSettingsV1(SyftObject): __canonical_name__ = "ServerSettings" @@ -93,7 +143,7 @@ class ServerSettingsV1(SyftObject): @serializable() -class ServerSettings(SyftObject): +class ServerSettingsV2(SyftObject): __canonical_name__ = "ServerSettings" __version__ = SYFT_OBJECT_VERSION_2 __repr_attrs__ = [ @@ -125,6 +175,41 @@ class ServerSettings(SyftObject): ) notifications_enabled: bool + +@serializable() +class ServerSettings(SyftObject): + __canonical_name__ = "ServerSettings" + __version__ = SYFT_OBJECT_VERSION_3 + __repr_attrs__ = [ + "name", + "organization", + "description", + "deployed_on", + "signup_enabled", + "admin_email", + ] + + id: UID + name: str = "Server" + deployed_on: str + organization: str = "OpenMined" + verify_key: SyftVerifyKey + on_board: bool = True + description: str = "This is the default description for a Datasite Server." + server_type: ServerType = ServerType.DATASITE + signup_enabled: bool + admin_email: str + server_side_type: ServerSideType = ServerSideType.HIGH_SIDE + show_warnings: bool + association_request_auto_approval: bool + eager_execution_enabled: bool = False + default_worker_pool: str = DEFAULT_WORKER_POOL_NAME + welcome_markdown: HTMLObject | MarkdownDescription = HTMLObject( + text=DEFAULT_WELCOME_MSG + ) + notifications_enabled: bool + pwd_token_config: PwdTokenResetConfig = PwdTokenResetConfig() + def _repr_html_(self) -> Any: # .api.services.notifications.settings() is how the server itself would dispatch notifications. # .api.services.notifications.user_settings() sets if a specific user wants or not to receive notifications. @@ -176,22 +261,53 @@ def _repr_html_(self) -> Any: """ -@migrate(ServerSettingsV1, ServerSettings) +# Server Settings Migration + + +# set +@migrate(ServerSettingsV1, ServerSettingsV2) def migrate_server_settings_v1_to_v2() -> list[Callable]: return [make_set_default("notifications_enabled", False)] -@migrate(ServerSettings, ServerSettingsV1) +@migrate(ServerSettingsV2, ServerSettings) +def migrate_server_settings_v2_to_current() -> list[Callable]: + return [make_set_default("pwd_token_config", PwdTokenResetConfig())] + + +# drop +@migrate(ServerSettingsV2, ServerSettingsV1) def migrate_server_settings_v2_to_v1() -> list[Callable]: # Use drop function on "notifications_enabled" attrubute return [drop(["notifications_enabled"])] -@migrate(ServerSettingsUpdateV1, ServerSettingsUpdate) +@migrate(ServerSettings, ServerSettingsV2) +def migrate_server_settings_current_to_v2() -> list[Callable]: + # Use drop function on "notifications_enabled" attrubute + return [drop(["pwd_token_config"])] + + +# Server Settings Update Migration + + +# set +@migrate(ServerSettingsUpdateV1, ServerSettingsUpdateV2) def migrate_server_settings_update_v1_to_v2() -> list[Callable]: return [make_set_default("notifications_enabled", False)] -@migrate(ServerSettingsUpdate, ServerSettingsUpdateV1) +@migrate(ServerSettingsUpdateV2, ServerSettingsUpdate) +def migrate_server_settings_update_v2_to_current() -> list[Callable]: + return [make_set_default("pwd_token_config", PwdTokenResetConfig())] + + +# drop +@migrate(ServerSettingsUpdateV2, ServerSettingsUpdateV1) def migrate_server_settings_update_v2_to_v1() -> list[Callable]: return [drop(["notifications_enabled"])] + + +@migrate(ServerSettingsUpdate, ServerSettingsUpdateV2) +def migrate_server_settings_update_current_to_v2() -> list[Callable]: + return [drop(["pwd_token_config"])] diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index 17d54df1dc0..b54abacd078 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -20,6 +20,7 @@ from ...util.schema import GUEST_COMMANDS from ..context import AuthedServiceContext from ..context import UnauthedServiceContext +from ..notifier.notifier_enums import EMAIL_TYPES from ..response import SyftError from ..response import SyftSuccess from ..service import AbstractService @@ -68,7 +69,12 @@ def set( else: return SyftError(message=result.err()) - @service_method(path="settings.update", name="update", autosplat=["settings"]) + @service_method( + path="settings.update", + name="update", + autosplat=["settings"], + roles=ADMIN_ROLE_LEVEL, + ) def update( self, context: AuthedServiceContext, settings: ServerSettingsUpdate ) -> Result[SyftSuccess, SyftError]: @@ -254,6 +260,13 @@ def enable_eager_execution( message = "enabled" if enable else "disabled" return SyftSuccess(message=f"Eager execution {message}") + @service_method(path="settings.set_email_rate_limit", name="set_email_rate_limit") + def set_email_rate_limit( + self, context: AuthedServiceContext, email_type: EMAIL_TYPES, daily_limit: int + ) -> SyftSuccess | SyftError: + notifier_service = context.server.get_service("notifierservice") + return notifier_service.set_email_rate_limit(context, email_type, daily_limit) + @service_method( path="settings.allow_association_request_auto_approval", name="allow_association_request_auto_approval", diff --git a/packages/syft/src/syft/service/user/user.py b/packages/syft/src/syft/service/user/user.py index 383ed3c5f6a..cb9b8860c03 100644 --- a/packages/syft/src/syft/service/user/user.py +++ b/packages/syft/src/syft/service/user/user.py @@ -1,6 +1,8 @@ # stdlib from collections.abc import Callable +from datetime import datetime from getpass import getpass +import re from typing import Any # third party @@ -17,8 +19,10 @@ from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey from ...types.syft_metaclass import Empty +from ...types.syft_migration import migrate from ...types.syft_object import PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import drop @@ -35,7 +39,7 @@ @serializable() -class User(SyftObject): +class UserV1(SyftObject): # version __canonical_name__ = "User" __version__ = SYFT_OBJECT_VERSION_1 @@ -68,10 +72,77 @@ class User(SyftObject): __repr_attrs__ = ["name", "email"] +@serializable() +class User(SyftObject): + # version + __canonical_name__ = "User" + __version__ = SYFT_OBJECT_VERSION_2 + + id: UID | None = None # type: ignore[assignment] + + # fields + notifications_enabled: dict[NOTIFIERS, bool] = { + NOTIFIERS.EMAIL: True, + NOTIFIERS.SMS: False, + NOTIFIERS.SLACK: False, + NOTIFIERS.APP: False, + } + email: EmailStr | None = None + name: str | None = None + hashed_password: str | None = None + salt: str | None = None + signing_key: SyftSigningKey | None = None + verify_key: SyftVerifyKey | None = None + role: ServiceRole | None = None + institution: str | None = None + website: str | None = None + created_at: str | None = None + # TODO where do we put this flag? + mock_execution_permission: bool = False + reset_token: str | None = None + reset_token_date: datetime | None = None + # serde / storage rules + __attr_searchable__ = ["name", "email", "verify_key", "role", "reset_token"] + __attr_unique__ = ["email", "signing_key", "verify_key"] + __repr_attrs__ = ["name", "email"] + + +@migrate(UserV1, User) +def migrate_server_user_update_v1_current() -> list[Callable]: + return [ + make_set_default("reset_token", None), + make_set_default("reset_token_date", None), + drop("__attr_searchable__"), + make_set_default( + "__attr_searchable__", + ["name", "email", "verify_key", "role", "reset_token"], + ), + ] + + +@migrate(User, UserV1) +def migrate_server_user_downgrade_current_v1() -> list[Callable]: + return [ + drop("reset_token"), + drop("reset_token_date"), + drop("__attr_searchable__"), + make_set_default( + "__attr_searchable__", ["name", "email", "verify_key", "role"] + ), + ] + + def default_role(role: ServiceRole) -> Callable: return make_set_default(key="role", value=role) +def validate_password(password: str) -> bool: + # Define the regex pattern for the password + pattern = re.compile(r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d).{8,}$") + + return bool(pattern.match(password)) + + def hash_password(context: TransformContext) -> TransformContext: if context.output is None: return context @@ -325,6 +396,24 @@ def user_create_to_user() -> list[Callable]: ] +@transform(UserV1, UserView) +def userv1_to_view_user() -> list[Callable]: + return [ + keep( + [ + "id", + "email", + "name", + "role", + "institution", + "website", + "mock_execution_permission", + "notifications_enabled", + ] + ) + ] + + @transform(User, UserView) def user_to_view_user() -> list[Callable]: return [ @@ -353,6 +442,11 @@ class UserPrivateKey(SyftObject): role: ServiceRole +@transform(UserV1, UserPrivateKey) +def userv1_to_user_verify() -> list[Callable]: + return [keep(["email", "signing_key", "id", "role"])] + + @transform(User, UserPrivateKey) def user_to_user_verify() -> list[Callable]: return [keep(["email", "signing_key", "id", "role"])] diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 584ed6cce48..336f94b5e42 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -1,4 +1,8 @@ # stdlib +from datetime import datetime +from datetime import timedelta +import secrets +import string # relative from ...abstract_server import ServerType @@ -17,6 +21,7 @@ from ..context import ServerServiceContext from ..context import UnauthedServiceContext from ..notification.email_templates import OnBoardEmailTemplate +from ..notification.email_templates import PasswordResetTemplate from ..notification.notification_service import CreateNotification from ..notification.notification_service import NotificationService from ..notifier.notifier_enums import NOTIFIERS @@ -26,6 +31,7 @@ from ..service import SERVICE_TO_TYPES from ..service import TYPE_TO_SERVICE from ..service import service_method +from ..settings.settings import PwdTokenResetConfig from ..settings.settings_stash import SettingsStash from .user import User from .user import UserCreate @@ -36,6 +42,7 @@ from .user import UserViewPage from .user import check_pwd from .user import salt_and_hash_password +from .user import validate_password from .user_roles import ADMIN_ROLE_LEVEL from .user_roles import DATA_OWNER_ROLE_LEVEL from .user_roles import DATA_SCIENTIST_ROLE_LEVEL @@ -84,6 +91,206 @@ def create( user = result.ok() return user.to(UserView) + @service_method( + path="user.forgot_password", name="forgot_password", roles=GUEST_ROLE_LEVEL + ) + def forgot_password( + self, context: AuthedServiceContext, email: str + ) -> SyftSuccess | SyftError: + success_msg = ( + "If the email is valid, we sent a password " + + "reset token to your email or a password request to the admin." + ) + result = self.stash.get_by_email(credentials=context.credentials, email=email) + # Isn't a valid email + if result.is_err(): + return SyftSuccess(message=success_msg) + user = result.ok() + + user_role = self.get_role_for_credentials(user.verify_key) + if user_role == ServiceRole.ADMIN: + return SyftError( + message="You can't request password reset for an Admin user." + ) + + # Email is valid + # Notifications disabled + # We should just sent a notification to the admin/user about password reset + # Notifications Enabled + # Instead of changing the password here, we would change it in email template generation. + root_key = self.admin_verify_key() + root_context = AuthedServiceContext(server=context.server, credentials=root_key) + link = LinkedObject.with_context(user, context=root_context) + notifier_service = context.server.get_service("notifierservice") + # Notifier is active + notifier = notifier_service.settings(context=root_context) + notification_is_enabled = notifier.active + # Email is enabled + email_is_enabled = notifier.email_enabled + # User Preferences allow email notification + user_allow_email_notifications = user.notifications_enabled[NOTIFIERS.EMAIL] + + # This checks if the user will safely receive the email reset. + not_receive_emails = ( + not notification_is_enabled + or not email_is_enabled + or not user_allow_email_notifications + ) + + # If notifier service is not enabled. + if not_receive_emails: + message = CreateNotification( + subject="You requested password reset.", + from_user_verify_key=root_key, + to_user_verify_key=user.verify_key, + linked_obj=link, + ) + + method = context.server.get_service_method(NotificationService.send) + result = method(context=root_context, notification=message) + + message = CreateNotification( + subject="User requested password reset.", + from_user_verify_key=user.verify_key, + to_user_verify_key=root_key, + linked_obj=link, + ) + + result = method(context=root_context, notification=message) + if isinstance(result, SyftError): + return result + else: + # Email notification is Enabled + # Therefore, we can directly send a message to the + # user with its new password. + message = CreateNotification( + subject="You requested a password reset.", + from_user_verify_key=root_key, + to_user_verify_key=user.verify_key, + linked_obj=link, + notifier_types=[NOTIFIERS.EMAIL], + email_template=PasswordResetTemplate, + ) + + method = context.server.get_service_method(NotificationService.send) + result = method(context=root_context, notification=message) + if isinstance(result, SyftError): + return result + + return SyftSuccess(message=success_msg) + + @service_method( + path="user.request_password_reset", + name="request_password_reset", + roles=ADMIN_ROLE_LEVEL, + ) + def request_password_reset( + self, context: AuthedServiceContext, uid: UID + ) -> str | SyftError: + result = self.stash.get_by_uid(credentials=context.credentials, uid=uid) + if result.is_err(): + return SyftError( + message=( + f"Failed to retrieve user with UID: {uid}. Error: {str(result.err())}" + ) + ) + user = result.ok() + if user is None: + return SyftError(message=f"No user exists for given: {uid}") + + user_role = self.get_role_for_credentials(user.verify_key) + if user_role == ServiceRole.ADMIN: + return SyftError( + message="You can't request password reset for an Admin user." + ) + + user.reset_token = self.generate_new_password_reset_token( + context.server.settings.pwd_token_config + ) + user.reset_token_date = datetime.now() + + result = self.stash.update( + credentials=context.credentials, user=user, has_permission=True + ) + if result.is_err(): + return SyftError( + message=( + f"Failed to update user with UID: {uid}. Error: {str(result.err())}" + ) + ) + + return user.reset_token + + @service_method( + path="user.reset_password", name="reset_password", roles=GUEST_ROLE_LEVEL + ) + def reset_password( + self, context: AuthedServiceContext, token: str, new_password: str + ) -> SyftSuccess | SyftError: + """Resets a certain user password using a temporary token.""" + result = self.stash.get_by_reset_token( + credentials=context.credentials, token=token + ) + invalid_token_error = SyftError( + message=("Failed to reset user password. Token is invalid or expired!") + ) + + if result.is_err(): + return SyftError(message="Failed to reset user password.") + + user = result.ok() + + # If token isn't found + if user is None: + return invalid_token_error + + now = datetime.now() + time_difference = now - user.reset_token_date + + # If token expired + expiration_time = context.server.settings.pwd_token_config.token_exp_min + if time_difference > timedelta(minutes=expiration_time): + return invalid_token_error + + if not validate_password(new_password): + return SyftError( + message="Your new password must have at least 8 \ + characters, Upper case and lower case characters\ + and at least one number." + ) + + salt, hashed = salt_and_hash_password(new_password, 12) + user.hashed_password = hashed + user.salt = salt + + user.reset_token = None + user.reset_token_date = None + + result = self.stash.update( + credentials=context.credentials, user=user, has_permission=True + ) + if result.is_err(): + return SyftError( + message=(f"Failed to update user password. Error: {str(result.err())}") + ) + return SyftSuccess(message="User Password updated successfully!") + + def generate_new_password_reset_token( + self, token_config: PwdTokenResetConfig + ) -> str: + valid_characters = "" + if token_config.ascii: + valid_characters += string.ascii_letters + + if token_config.numbers: + valid_characters += string.digits + + generated_token = "".join( + secrets.choice(valid_characters) for _ in range(token_config.token_len) + ) + + return generated_token + @service_method(path="user.view", name="view", roles=DATA_SCIENTIST_ROLE_LEVEL) def view( self, context: AuthedServiceContext, uid: UID diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 2b2b42db9e8..894d9a65115 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -23,6 +23,7 @@ # 🟡 TODO 27: it would be nice if these could be defined closer to the User EmailPartitionKey = PartitionKey(key="email", type_=str) +PasswordResetTokenPartitionKey = PartitionKey(key="reset_token", type_=str) RolePartitionKey = PartitionKey(key="role", type_=ServiceRole) SigningKeyPartitionKey = PartitionKey(key="signing_key", type_=SyftSigningKey) VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey) @@ -74,6 +75,12 @@ def get_by_uid( qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) return self.query_one(credentials=credentials, qks=qks) + def get_by_reset_token( + self, credentials: SyftVerifyKey, token: str + ) -> Result[User | None, str]: + qks = QueryKeys(qks=[PasswordResetTokenPartitionKey.with_obj(token)]) + return self.query_one(credentials=credentials, qks=qks) + def get_by_email( self, credentials: SyftVerifyKey, email: str ) -> Result[User | None, str]: diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py index d5cc342635e..3d0548f6cf1 100644 --- a/packages/syft/src/syft/types/syft_object_registry.py +++ b/packages/syft/src/syft/types/syft_object_registry.py @@ -131,7 +131,10 @@ def get_transform( klass_from = type_from_mro.__name__ version_from = None for type_to_mro in type_to.mro(): - if issubclass(type_to_mro, SyftBaseObject): + if ( + issubclass(type_to_mro, SyftBaseObject) + and type_to_mro != SyftBaseObject + ): klass_to = type_to_mro.__canonical_name__ version_to = type_to_mro.__version__ else: diff --git a/tox.ini b/tox.ini index 5fd5cf7fc6f..4e5ae2d2001 100644 --- a/tox.ini +++ b/tox.ini @@ -1078,7 +1078,7 @@ commands = description = Prepare Migration Data pip_pre = True deps = - syft==0.8.7 + syft==0.8.8 nbmake allowlist_externals = bash