diff --git a/notebooks/api/0.8/13-forgot-user-password.ipynb b/notebooks/api/0.8/13-forgot-user-password.ipynb
new file mode 100644
index 00000000000..8ad3cdf0918
--- /dev/null
+++ b/notebooks/api/0.8/13-forgot-user-password.ipynb
@@ -0,0 +1,181 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0",
+ "metadata": {},
+ "source": [
+ "# Forgot User Password"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1",
+ "metadata": {},
+ "source": [
+ "## Initialize the server"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# stdlib\n",
+ "\n",
+ "# syft absolute\n",
+ "import syft as sy\n",
+ "from syft import SyftError\n",
+ "from syft import SyftSuccess\n",
+ "\n",
+ "server = sy.orchestra.launch(\n",
+ " name=\"test-datasite-1\",\n",
+ " dev_mode=True,\n",
+ " create_producer=True,\n",
+ " n_consumers=3,\n",
+ " reset=True,\n",
+ " port=8081,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3",
+ "metadata": {},
+ "source": [
+ "## Register a new user"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "datasite_client = server.login(email=\"info@openmined.org\", password=\"changethis\")\n",
+ "res = datasite_client.register(\n",
+ " email=\"new_syft_user@openmined.org\",\n",
+ " password=\"verysecurepassword\",\n",
+ " password_verify=\"verysecurepassword\",\n",
+ " name=\"New User\",\n",
+ ")\n",
+ "\n",
+ "if not isinstance(res, SyftSuccess):\n",
+ " raise Exception(f\"Res isn't SyftSuccess, its {res}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5",
+ "metadata": {},
+ "source": [
+ "### Ask for a password reset - Notifier disabled Workflow"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6",
+ "metadata": {},
+ "source": [
+ "### Call for users.forgot_password"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "guest_client = server.login_as_guest()\n",
+ "res = guest_client.users.forgot_password(email=\"new_syft_user@openmined.org\")\n",
+ "\n",
+ "if not isinstance(res, SyftSuccess):\n",
+ " raise Exception(f\"Res isn't SyftSuccess, its {res}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8",
+ "metadata": {},
+ "source": [
+ "### Admin generates a temp token"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "temp_token = datasite_client.users.request_password_reset(\n",
+ " datasite_client.notifications[-1].linked_obj.resolve.id\n",
+ ")\n",
+ "\n",
+ "if not isinstance(temp_token, str):\n",
+ " raise Exception(f\"temp_token isn't a string, its {temp_token}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "10",
+ "metadata": {},
+ "source": [
+ "### User use this token to reset password"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "11",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "res = guest_client.users.reset_password(token=temp_token, new_password=\"Password123\")\n",
+ "\n",
+ "if not isinstance(res, SyftSuccess):\n",
+ " raise Exception(f\"Res isn't SyftSuccess, its {res}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "12",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_user_session = server.login(\n",
+ " email=\"new_syft_user@openmined.org\", password=\"Password123\"\n",
+ ")\n",
+ "\n",
+ "if isinstance(new_user_session, SyftError):\n",
+ " raise Exception(f\"Res isn't SyftSuccess, its {new_user_session}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json
index 0eb298b0648..9cf7d42efd0 100644
--- a/packages/syft/src/syft/protocol/protocol_version.json
+++ b/packages/syft/src/syft/protocol/protocol_version.json
@@ -4,5 +4,51 @@
},
"2": {
"release_name": "0.8.8.json"
+ },
+ "dev": {
+ "object_versions": {
+ "User": {
+ "2": {
+ "version": 2,
+ "hash": "af6fb5b2e1606e97838f4a60f0536ad95db606d455e94acbd1977df866608a2c",
+ "action": "add"
+ }
+ },
+ "UserNotificationActivity": {
+ "1": {
+ "version": 1,
+ "hash": "422fd01c6d9af38688a9982abd34e80794a1f6ddd444cca225d77f49189847a9",
+ "action": "add"
+ }
+ },
+ "NotifierSettings": {
+ "2": {
+ "version": 2,
+ "hash": "be8b52597fc628d1b7cd22b776ee81416e1adbb04a45188778eb0e32ed1416b4",
+ "action": "add"
+ }
+ },
+ "PwdTokenResetConfig": {
+ "1": {
+ "version": 1,
+ "hash": "0415a272428f22add4896c64aa9f29c8c1d35619e2433da6564eb5f1faff39ac",
+ "action": "add"
+ }
+ },
+ "ServerSettingsUpdate": {
+ "3": {
+ "version": 3,
+ "hash": "335c7946f2e52d09c7b26f511120cd340717c74c5cca9107e84f839da993c55c",
+ "action": "add"
+ }
+ },
+ "ServerSettings": {
+ "3": {
+ "version": 3,
+ "hash": "997667e1cba22d151857aacc2caba6b1ca73c1648adbd03461dc74a0c0c372b3",
+ "action": "add"
+ }
+ }
+ }
}
}
diff --git a/packages/syft/src/syft/service/notification/email_templates.py b/packages/syft/src/syft/service/notification/email_templates.py
index f8baceee38a..fec2810b02a 100644
--- a/packages/syft/src/syft/service/notification/email_templates.py
+++ b/packages/syft/src/syft/service/notification/email_templates.py
@@ -1,4 +1,5 @@
# stdlib
+from datetime import datetime
from typing import TYPE_CHECKING
from typing import cast
@@ -22,6 +23,91 @@ def email_body(notification: "Notification", context: AuthedServiceContext) -> s
return ""
+@serializable(canonical_name="PasswordResetTemplate", version=1)
+class PasswordResetTemplate(EmailTemplate):
+ @staticmethod
+ def email_title(notification: "Notification", context: AuthedServiceContext) -> str:
+ return "Password Reset Requested"
+
+ @staticmethod
+ def email_body(notification: "Notification", context: AuthedServiceContext) -> str:
+ user_service = context.server.get_service("userservice")
+
+ user = user_service.get_by_verify_key(notification.to_user_verify_key)
+ if not user:
+ raise Exception("User not found!")
+
+ user.reset_token = user_service.generate_new_password_reset_token(
+ context.server.settings.pwd_token_config
+ )
+ user.reset_token_date = datetime.now()
+
+ result = user_service.stash.update(
+ credentials=context.credentials, user=user, has_permission=True
+ )
+ if result.is_err():
+ raise Exception("Couldn't update the user password")
+
+ head = """
+
+ """
+ body = f"""
+
+
Password Reset
+
We received a request to reset your password. Your new temporary token is:
+
{user.reset_token}
+
Use
+
+ syft_client.reset_password(token='{user.reset_token}', new_password=*****)
+
.
+ to reset your password.
+
If you didn't request a password reset, please ignore this email.
+
+ """
+ return f"""{head} {body}"""
+
+
@serializable(canonical_name="OnboardEmailTemplate", version=1)
class OnBoardEmailTemplate(EmailTemplate):
@staticmethod
diff --git a/packages/syft/src/syft/service/notifier/notifier.py b/packages/syft/src/syft/service/notifier/notifier.py
index 6cc24a3d720..d9cdef82c2c 100644
--- a/packages/syft/src/syft/service/notifier/notifier.py
+++ b/packages/syft/src/syft/service/notifier/notifier.py
@@ -6,8 +6,9 @@
# 2) .....settings().x_enabled
# 2) .....user_settings().x
-
# stdlib
+from collections.abc import Callable
+from datetime import datetime
from typing import TypeVar
# third party
@@ -18,8 +19,12 @@
# relative
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
+from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
+from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
+from ...types.transforms import drop
+from ...types.transforms import make_set_default
from ..context import AuthedServiceContext
from ..notification.notifications import Notification
from ..response import SyftError
@@ -38,6 +43,14 @@ def send(
TBaseNotifier = TypeVar("TBaseNotifier", bound=BaseNotifier)
+@serializable()
+class UserNotificationActivity(SyftObject):
+ __canonical_name__ = "UserNotificationActivity"
+ __version__ = SYFT_OBJECT_VERSION_1
+ count: int = 1
+ date: datetime = datetime.now()
+
+
@serializable(canonical_name="EmailNotifier", version=1)
class EmailNotifier(BaseNotifier):
smtp_client: SMTPClient
@@ -131,7 +144,7 @@ class NotificationPreferences(SyftObject):
@serializable()
-class NotifierSettings(SyftObject):
+class NotifierSettingsV1(SyftObject):
__canonical_name__ = "NotifierSettings"
__version__ = SYFT_OBJECT_VERSION_1
__repr_attrs__ = [
@@ -139,6 +152,34 @@ class NotifierSettings(SyftObject):
"email_enabled",
]
active: bool = False
+
+ notifiers: dict[NOTIFIERS, type[TBaseNotifier]] = {
+ NOTIFIERS.EMAIL: EmailNotifier,
+ }
+
+ notifiers_status: dict[NOTIFIERS, bool] = {
+ NOTIFIERS.EMAIL: True,
+ NOTIFIERS.SMS: False,
+ NOTIFIERS.SLACK: False,
+ NOTIFIERS.APP: False,
+ }
+
+ email_sender: str | None = ""
+ email_server: str | None = ""
+ email_port: int | None = 587
+ email_username: str | None = ""
+ email_password: str | None = ""
+
+
+@serializable()
+class NotifierSettings(SyftObject):
+ __canonical_name__ = "NotifierSettings"
+ __version__ = SYFT_OBJECT_VERSION_2
+ __repr_attrs__ = [
+ "active",
+ "email_enabled",
+ ]
+ active: bool = False
# Flag to identify which notification is enabled
# For now, consider only the email notification
# In future, Admin, must be able to have a better
@@ -161,6 +202,9 @@ class NotifierSettings(SyftObject):
email_username: str | None = ""
email_password: str | None = ""
+ email_activity: dict[str, dict[SyftVerifyKey, UserNotificationActivity]] = {}
+ email_rate_limit: dict[str, int] = {}
+
@property
def email_enabled(self) -> bool:
return self.notifiers_status[NOTIFIERS.EMAIL]
@@ -237,3 +281,17 @@ def select_notifiers(self, notification: Notification) -> list[BaseNotifier]:
notifier_objs.append(self.notifiers[notifier_type]()) # type: ignore[misc]
return notifier_objs
+
+
+@migrate(NotifierSettingsV1, NotifierSettings)
+def migrate_server_settings_v1_to_current() -> list[Callable]:
+ return [
+ make_set_default("email_activity", {}),
+ make_set_default("email_rate_limit", {}),
+ ]
+
+
+@migrate(NotifierSettings, NotifierSettingsV1)
+def migrate_server_settings_v2_to_v1() -> list[Callable]:
+ # Use drop function on "notifications_enabled" attrubute
+ return [drop(["email_activity"]), drop(["email_rate_limit"])]
diff --git a/packages/syft/src/syft/service/notifier/notifier_enums.py b/packages/syft/src/syft/service/notifier/notifier_enums.py
index 023843f7d6c..f8c2d887ff4 100644
--- a/packages/syft/src/syft/service/notifier/notifier_enums.py
+++ b/packages/syft/src/syft/service/notifier/notifier_enums.py
@@ -6,6 +6,14 @@
from ...serde.serializable import serializable
+@serializable(canonical_name="EMAIL_TYPES", version=1)
+class EMAIL_TYPES(Enum):
+ PASSWORD_RESET_EMAIL = "PasswordResetTemplate" # nosec
+ ONBOARD_EMAIL = "OnBoardEmailTemplate"
+ REQUEST_EMAIL = "RequestEmailTemplate"
+ REQUEST_UPDATE_EMAIL = "RequestUpdateEmailTemplate"
+
+
@serializable(canonical_name="NOTIFIERS", version=1)
class NOTIFIERS(Enum):
EMAIL = auto()
diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py
index 391bf16abd4..c8c09ba3d50 100644
--- a/packages/syft/src/syft/service/notifier/notifier_service.py
+++ b/packages/syft/src/syft/service/notifier/notifier_service.py
@@ -1,4 +1,5 @@
# stdlib
+from datetime import datetime
import logging
import traceback
@@ -13,12 +14,15 @@
from ...serde.serializable import serializable
from ...store.document_store import DocumentStore
from ..context import AuthedServiceContext
+from ..notification.email_templates import PasswordResetTemplate
from ..notification.notifications import Notification
from ..response import SyftError
from ..response import SyftSuccess
from ..service import AbstractService
from .notifier import NotificationPreferences
from .notifier import NotifierSettings
+from .notifier import UserNotificationActivity
+from .notifier_enums import EMAIL_TYPES
from .notifier_enums import NOTIFIERS
from .notifier_stash import NotifierStash
@@ -188,6 +192,10 @@ def turn_on(
message="You must provide a sender email address to enable notifications."
)
+ # If email_rate_limit isn't defined yet.
+ if not notifier.email_rate_limit:
+ notifier.email_rate_limit = {PasswordResetTemplate.__name__: 3}
+
if email_sender:
try:
EmailStr._validate(email_sender)
@@ -320,6 +328,8 @@ def init_notifier(
notifier.email_sender = email_sender
notifier.email_server = smtp_host
notifier.email_port = smtp_port
+ # Default daily email rate limit per user
+ notifier.email_rate_limit = {PasswordResetTemplate.__name__: 3}
notifier.active = True
notifier_stash.set(server.signing_key.verify_key, notifier)
@@ -328,6 +338,22 @@ def init_notifier(
except Exception:
raise Exception(f"Error initializing notifier. \n {traceback.format_exc()}")
+ def set_email_rate_limit(
+ self, context: AuthedServiceContext, email_type: EMAIL_TYPES, daily_limit: int
+ ) -> SyftSuccess | SyftError:
+ notifier = self.stash.get(context.credentials)
+ if notifier.is_err():
+ return SyftError(message="Couldn't set the email rate limit.")
+
+ notifier = notifier.ok()
+
+ notifier.email_rate_limit[email_type.value] = daily_limit
+ result = self.stash.update(credentials=context.credentials, settings=notifier)
+ if result.is_err():
+ return SyftError(message="Couldn't update the notifier.")
+
+ return SyftSuccess(message="Email rate limit updated!")
+
# This is not a public API.
# This method is used by other services to dispatch notifications internally
def dispatch_notification(
@@ -343,7 +369,51 @@ def dispatch_notification(
notifier = notifier.ok()
# If notifier is active
- if notifier.active:
+ if notifier.active and notification.email_template is not None:
+ logging.debug("Checking user email activity")
+ if notifier.email_activity.get(notification.email_template.__name__, None):
+ user_activity = notifier.email_activity[
+ notification.email_template.__name__
+ ].get(notification.to_user_verify_key, None)
+ # If there's no user activity
+ if user_activity is None:
+ notifier.email_activity[notification.email_template.__name__][
+ notification.to_user_verify_key, None
+ ] = UserNotificationActivity(count=1, date=datetime.now())
+ else: # If there's a previous user activity
+ current_state: UserNotificationActivity = notifier.email_activity[
+ notification.email_template.__name__
+ ][notification.to_user_verify_key]
+ date_refresh = abs(datetime.now() - current_state.date).days > 1
+
+ limit = notifier.email_rate_limit.get(
+ notification.email_template.__name__, 0
+ )
+ still_in_limit = current_state.count < limit
+ # Time interval reseted.
+ if date_refresh:
+ current_state.count = 1
+ current_state.date = datetime.now()
+ # Time interval didn't reset yet.
+ elif still_in_limit or not limit:
+ current_state.count += 1
+ current_state.date = datetime.now()
+ else:
+ return SyftError(
+ message="Couldn't send the email. You have surpassed the"
+ + " email threshold limit. Please try again later."
+ )
+ else:
+ notifier.email_activity[notification.email_template.__name__] = {
+ notification.to_user_verify_key: UserNotificationActivity(
+ count=1, date=datetime.now()
+ )
+ }
+
+ result = self.stash.update(credentials=admin_key, settings=notifier)
+ if result.is_err():
+ return SyftError(message="Couldn't update the notifier.")
+
resp = notifier.send_notifications(
context=context, notification=notification
)
diff --git a/packages/syft/src/syft/service/settings/settings.py b/packages/syft/src/syft/service/settings/settings.py
index bf3369d7cb9..67720658c80 100644
--- a/packages/syft/src/syft/service/settings/settings.py
+++ b/packages/syft/src/syft/service/settings/settings.py
@@ -3,6 +3,11 @@
import logging
from typing import Any
+# third party
+from pydantic import field_validator
+from pydantic import model_validator
+from typing_extensions import Self
+
# relative
from ...abstract_server import ServerSideType
from ...abstract_server import ServerType
@@ -13,6 +18,7 @@
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SYFT_OBJECT_VERSION_2
+from ...types.syft_object import SYFT_OBJECT_VERSION_3
from ...types.syft_object import SyftObject
from ...types.transforms import drop
from ...types.transforms import make_set_default
@@ -26,6 +32,32 @@
logger = logging.getLogger(__name__)
+@serializable()
+class PwdTokenResetConfig(SyftObject):
+ __canonical_name__ = "PwdTokenResetConfig"
+ __version__ = SYFT_OBJECT_VERSION_1
+ ascii: bool = True
+ numbers: bool = True
+ token_len: int = 12
+ token_exp_min: int = 30
+
+ @model_validator(mode="after")
+ def validate_char_types(self) -> Self:
+ if not self.ascii and not self.numbers:
+ raise ValueError(
+ "Invalid config, at least one of the ascii/number options must be true."
+ )
+
+ return self
+
+ @field_validator("token_len")
+ @classmethod
+ def check_token_len(cls, value: int) -> int:
+ if value < 4:
+ raise ValueError("Token length must be greater than 4.")
+ return value
+
+
@serializable()
class ServerSettingsUpdateV1(PartialSyftObject):
__canonical_name__ = "ServerSettingsUpdate"
@@ -43,7 +75,7 @@ class ServerSettingsUpdateV1(PartialSyftObject):
@serializable()
-class ServerSettingsUpdate(PartialSyftObject):
+class ServerSettingsUpdateV2(PartialSyftObject):
__canonical_name__ = "ServerSettingsUpdate"
__version__ = SYFT_OBJECT_VERSION_2
id: UID
@@ -59,6 +91,24 @@ class ServerSettingsUpdate(PartialSyftObject):
notifications_enabled: bool
+@serializable()
+class ServerSettingsUpdate(PartialSyftObject):
+ __canonical_name__ = "ServerSettingsUpdate"
+ __version__ = SYFT_OBJECT_VERSION_3
+ id: UID
+ name: str
+ organization: str
+ description: str
+ on_board: bool
+ signup_enabled: bool
+ admin_email: str
+ association_request_auto_approval: bool
+ welcome_markdown: HTMLObject | MarkdownDescription
+ eager_execution_enabled: bool
+ notifications_enabled: bool
+ pwd_token_config: PwdTokenResetConfig
+
+
@serializable()
class ServerSettingsV1(SyftObject):
__canonical_name__ = "ServerSettings"
@@ -93,7 +143,7 @@ class ServerSettingsV1(SyftObject):
@serializable()
-class ServerSettings(SyftObject):
+class ServerSettingsV2(SyftObject):
__canonical_name__ = "ServerSettings"
__version__ = SYFT_OBJECT_VERSION_2
__repr_attrs__ = [
@@ -125,6 +175,41 @@ class ServerSettings(SyftObject):
)
notifications_enabled: bool
+
+@serializable()
+class ServerSettings(SyftObject):
+ __canonical_name__ = "ServerSettings"
+ __version__ = SYFT_OBJECT_VERSION_3
+ __repr_attrs__ = [
+ "name",
+ "organization",
+ "description",
+ "deployed_on",
+ "signup_enabled",
+ "admin_email",
+ ]
+
+ id: UID
+ name: str = "Server"
+ deployed_on: str
+ organization: str = "OpenMined"
+ verify_key: SyftVerifyKey
+ on_board: bool = True
+ description: str = "This is the default description for a Datasite Server."
+ server_type: ServerType = ServerType.DATASITE
+ signup_enabled: bool
+ admin_email: str
+ server_side_type: ServerSideType = ServerSideType.HIGH_SIDE
+ show_warnings: bool
+ association_request_auto_approval: bool
+ eager_execution_enabled: bool = False
+ default_worker_pool: str = DEFAULT_WORKER_POOL_NAME
+ welcome_markdown: HTMLObject | MarkdownDescription = HTMLObject(
+ text=DEFAULT_WELCOME_MSG
+ )
+ notifications_enabled: bool
+ pwd_token_config: PwdTokenResetConfig = PwdTokenResetConfig()
+
def _repr_html_(self) -> Any:
# .api.services.notifications.settings() is how the server itself would dispatch notifications.
# .api.services.notifications.user_settings() sets if a specific user wants or not to receive notifications.
@@ -176,22 +261,53 @@ def _repr_html_(self) -> Any:
"""
-@migrate(ServerSettingsV1, ServerSettings)
+# Server Settings Migration
+
+
+# set
+@migrate(ServerSettingsV1, ServerSettingsV2)
def migrate_server_settings_v1_to_v2() -> list[Callable]:
return [make_set_default("notifications_enabled", False)]
-@migrate(ServerSettings, ServerSettingsV1)
+@migrate(ServerSettingsV2, ServerSettings)
+def migrate_server_settings_v2_to_current() -> list[Callable]:
+ return [make_set_default("pwd_token_config", PwdTokenResetConfig())]
+
+
+# drop
+@migrate(ServerSettingsV2, ServerSettingsV1)
def migrate_server_settings_v2_to_v1() -> list[Callable]:
# Use drop function on "notifications_enabled" attrubute
return [drop(["notifications_enabled"])]
-@migrate(ServerSettingsUpdateV1, ServerSettingsUpdate)
+@migrate(ServerSettings, ServerSettingsV2)
+def migrate_server_settings_current_to_v2() -> list[Callable]:
+ # Use drop function on "notifications_enabled" attrubute
+ return [drop(["pwd_token_config"])]
+
+
+# Server Settings Update Migration
+
+
+# set
+@migrate(ServerSettingsUpdateV1, ServerSettingsUpdateV2)
def migrate_server_settings_update_v1_to_v2() -> list[Callable]:
return [make_set_default("notifications_enabled", False)]
-@migrate(ServerSettingsUpdate, ServerSettingsUpdateV1)
+@migrate(ServerSettingsUpdateV2, ServerSettingsUpdate)
+def migrate_server_settings_update_v2_to_current() -> list[Callable]:
+ return [make_set_default("pwd_token_config", PwdTokenResetConfig())]
+
+
+# drop
+@migrate(ServerSettingsUpdateV2, ServerSettingsUpdateV1)
def migrate_server_settings_update_v2_to_v1() -> list[Callable]:
return [drop(["notifications_enabled"])]
+
+
+@migrate(ServerSettingsUpdate, ServerSettingsUpdateV2)
+def migrate_server_settings_update_current_to_v2() -> list[Callable]:
+ return [drop(["pwd_token_config"])]
diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py
index 17d54df1dc0..b54abacd078 100644
--- a/packages/syft/src/syft/service/settings/settings_service.py
+++ b/packages/syft/src/syft/service/settings/settings_service.py
@@ -20,6 +20,7 @@
from ...util.schema import GUEST_COMMANDS
from ..context import AuthedServiceContext
from ..context import UnauthedServiceContext
+from ..notifier.notifier_enums import EMAIL_TYPES
from ..response import SyftError
from ..response import SyftSuccess
from ..service import AbstractService
@@ -68,7 +69,12 @@ def set(
else:
return SyftError(message=result.err())
- @service_method(path="settings.update", name="update", autosplat=["settings"])
+ @service_method(
+ path="settings.update",
+ name="update",
+ autosplat=["settings"],
+ roles=ADMIN_ROLE_LEVEL,
+ )
def update(
self, context: AuthedServiceContext, settings: ServerSettingsUpdate
) -> Result[SyftSuccess, SyftError]:
@@ -254,6 +260,13 @@ def enable_eager_execution(
message = "enabled" if enable else "disabled"
return SyftSuccess(message=f"Eager execution {message}")
+ @service_method(path="settings.set_email_rate_limit", name="set_email_rate_limit")
+ def set_email_rate_limit(
+ self, context: AuthedServiceContext, email_type: EMAIL_TYPES, daily_limit: int
+ ) -> SyftSuccess | SyftError:
+ notifier_service = context.server.get_service("notifierservice")
+ return notifier_service.set_email_rate_limit(context, email_type, daily_limit)
+
@service_method(
path="settings.allow_association_request_auto_approval",
name="allow_association_request_auto_approval",
diff --git a/packages/syft/src/syft/service/user/user.py b/packages/syft/src/syft/service/user/user.py
index 383ed3c5f6a..cb9b8860c03 100644
--- a/packages/syft/src/syft/service/user/user.py
+++ b/packages/syft/src/syft/service/user/user.py
@@ -1,6 +1,8 @@
# stdlib
from collections.abc import Callable
+from datetime import datetime
from getpass import getpass
+import re
from typing import Any
# third party
@@ -17,8 +19,10 @@
from ...server.credentials import SyftSigningKey
from ...server.credentials import SyftVerifyKey
from ...types.syft_metaclass import Empty
+from ...types.syft_migration import migrate
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_1
+from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.transforms import TransformContext
from ...types.transforms import drop
@@ -35,7 +39,7 @@
@serializable()
-class User(SyftObject):
+class UserV1(SyftObject):
# version
__canonical_name__ = "User"
__version__ = SYFT_OBJECT_VERSION_1
@@ -68,10 +72,77 @@ class User(SyftObject):
__repr_attrs__ = ["name", "email"]
+@serializable()
+class User(SyftObject):
+ # version
+ __canonical_name__ = "User"
+ __version__ = SYFT_OBJECT_VERSION_2
+
+ id: UID | None = None # type: ignore[assignment]
+
+ # fields
+ notifications_enabled: dict[NOTIFIERS, bool] = {
+ NOTIFIERS.EMAIL: True,
+ NOTIFIERS.SMS: False,
+ NOTIFIERS.SLACK: False,
+ NOTIFIERS.APP: False,
+ }
+ email: EmailStr | None = None
+ name: str | None = None
+ hashed_password: str | None = None
+ salt: str | None = None
+ signing_key: SyftSigningKey | None = None
+ verify_key: SyftVerifyKey | None = None
+ role: ServiceRole | None = None
+ institution: str | None = None
+ website: str | None = None
+ created_at: str | None = None
+ # TODO where do we put this flag?
+ mock_execution_permission: bool = False
+ reset_token: str | None = None
+ reset_token_date: datetime | None = None
+ # serde / storage rules
+ __attr_searchable__ = ["name", "email", "verify_key", "role", "reset_token"]
+ __attr_unique__ = ["email", "signing_key", "verify_key"]
+ __repr_attrs__ = ["name", "email"]
+
+
+@migrate(UserV1, User)
+def migrate_server_user_update_v1_current() -> list[Callable]:
+ return [
+ make_set_default("reset_token", None),
+ make_set_default("reset_token_date", None),
+ drop("__attr_searchable__"),
+ make_set_default(
+ "__attr_searchable__",
+ ["name", "email", "verify_key", "role", "reset_token"],
+ ),
+ ]
+
+
+@migrate(User, UserV1)
+def migrate_server_user_downgrade_current_v1() -> list[Callable]:
+ return [
+ drop("reset_token"),
+ drop("reset_token_date"),
+ drop("__attr_searchable__"),
+ make_set_default(
+ "__attr_searchable__", ["name", "email", "verify_key", "role"]
+ ),
+ ]
+
+
def default_role(role: ServiceRole) -> Callable:
return make_set_default(key="role", value=role)
+def validate_password(password: str) -> bool:
+ # Define the regex pattern for the password
+ pattern = re.compile(r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d).{8,}$")
+
+ return bool(pattern.match(password))
+
+
def hash_password(context: TransformContext) -> TransformContext:
if context.output is None:
return context
@@ -325,6 +396,24 @@ def user_create_to_user() -> list[Callable]:
]
+@transform(UserV1, UserView)
+def userv1_to_view_user() -> list[Callable]:
+ return [
+ keep(
+ [
+ "id",
+ "email",
+ "name",
+ "role",
+ "institution",
+ "website",
+ "mock_execution_permission",
+ "notifications_enabled",
+ ]
+ )
+ ]
+
+
@transform(User, UserView)
def user_to_view_user() -> list[Callable]:
return [
@@ -353,6 +442,11 @@ class UserPrivateKey(SyftObject):
role: ServiceRole
+@transform(UserV1, UserPrivateKey)
+def userv1_to_user_verify() -> list[Callable]:
+ return [keep(["email", "signing_key", "id", "role"])]
+
+
@transform(User, UserPrivateKey)
def user_to_user_verify() -> list[Callable]:
return [keep(["email", "signing_key", "id", "role"])]
diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py
index 584ed6cce48..336f94b5e42 100644
--- a/packages/syft/src/syft/service/user/user_service.py
+++ b/packages/syft/src/syft/service/user/user_service.py
@@ -1,4 +1,8 @@
# stdlib
+from datetime import datetime
+from datetime import timedelta
+import secrets
+import string
# relative
from ...abstract_server import ServerType
@@ -17,6 +21,7 @@
from ..context import ServerServiceContext
from ..context import UnauthedServiceContext
from ..notification.email_templates import OnBoardEmailTemplate
+from ..notification.email_templates import PasswordResetTemplate
from ..notification.notification_service import CreateNotification
from ..notification.notification_service import NotificationService
from ..notifier.notifier_enums import NOTIFIERS
@@ -26,6 +31,7 @@
from ..service import SERVICE_TO_TYPES
from ..service import TYPE_TO_SERVICE
from ..service import service_method
+from ..settings.settings import PwdTokenResetConfig
from ..settings.settings_stash import SettingsStash
from .user import User
from .user import UserCreate
@@ -36,6 +42,7 @@
from .user import UserViewPage
from .user import check_pwd
from .user import salt_and_hash_password
+from .user import validate_password
from .user_roles import ADMIN_ROLE_LEVEL
from .user_roles import DATA_OWNER_ROLE_LEVEL
from .user_roles import DATA_SCIENTIST_ROLE_LEVEL
@@ -84,6 +91,206 @@ def create(
user = result.ok()
return user.to(UserView)
+ @service_method(
+ path="user.forgot_password", name="forgot_password", roles=GUEST_ROLE_LEVEL
+ )
+ def forgot_password(
+ self, context: AuthedServiceContext, email: str
+ ) -> SyftSuccess | SyftError:
+ success_msg = (
+ "If the email is valid, we sent a password "
+ + "reset token to your email or a password request to the admin."
+ )
+ result = self.stash.get_by_email(credentials=context.credentials, email=email)
+ # Isn't a valid email
+ if result.is_err():
+ return SyftSuccess(message=success_msg)
+ user = result.ok()
+
+ user_role = self.get_role_for_credentials(user.verify_key)
+ if user_role == ServiceRole.ADMIN:
+ return SyftError(
+ message="You can't request password reset for an Admin user."
+ )
+
+ # Email is valid
+ # Notifications disabled
+ # We should just sent a notification to the admin/user about password reset
+ # Notifications Enabled
+ # Instead of changing the password here, we would change it in email template generation.
+ root_key = self.admin_verify_key()
+ root_context = AuthedServiceContext(server=context.server, credentials=root_key)
+ link = LinkedObject.with_context(user, context=root_context)
+ notifier_service = context.server.get_service("notifierservice")
+ # Notifier is active
+ notifier = notifier_service.settings(context=root_context)
+ notification_is_enabled = notifier.active
+ # Email is enabled
+ email_is_enabled = notifier.email_enabled
+ # User Preferences allow email notification
+ user_allow_email_notifications = user.notifications_enabled[NOTIFIERS.EMAIL]
+
+ # This checks if the user will safely receive the email reset.
+ not_receive_emails = (
+ not notification_is_enabled
+ or not email_is_enabled
+ or not user_allow_email_notifications
+ )
+
+ # If notifier service is not enabled.
+ if not_receive_emails:
+ message = CreateNotification(
+ subject="You requested password reset.",
+ from_user_verify_key=root_key,
+ to_user_verify_key=user.verify_key,
+ linked_obj=link,
+ )
+
+ method = context.server.get_service_method(NotificationService.send)
+ result = method(context=root_context, notification=message)
+
+ message = CreateNotification(
+ subject="User requested password reset.",
+ from_user_verify_key=user.verify_key,
+ to_user_verify_key=root_key,
+ linked_obj=link,
+ )
+
+ result = method(context=root_context, notification=message)
+ if isinstance(result, SyftError):
+ return result
+ else:
+ # Email notification is Enabled
+ # Therefore, we can directly send a message to the
+ # user with its new password.
+ message = CreateNotification(
+ subject="You requested a password reset.",
+ from_user_verify_key=root_key,
+ to_user_verify_key=user.verify_key,
+ linked_obj=link,
+ notifier_types=[NOTIFIERS.EMAIL],
+ email_template=PasswordResetTemplate,
+ )
+
+ method = context.server.get_service_method(NotificationService.send)
+ result = method(context=root_context, notification=message)
+ if isinstance(result, SyftError):
+ return result
+
+ return SyftSuccess(message=success_msg)
+
+ @service_method(
+ path="user.request_password_reset",
+ name="request_password_reset",
+ roles=ADMIN_ROLE_LEVEL,
+ )
+ def request_password_reset(
+ self, context: AuthedServiceContext, uid: UID
+ ) -> str | SyftError:
+ result = self.stash.get_by_uid(credentials=context.credentials, uid=uid)
+ if result.is_err():
+ return SyftError(
+ message=(
+ f"Failed to retrieve user with UID: {uid}. Error: {str(result.err())}"
+ )
+ )
+ user = result.ok()
+ if user is None:
+ return SyftError(message=f"No user exists for given: {uid}")
+
+ user_role = self.get_role_for_credentials(user.verify_key)
+ if user_role == ServiceRole.ADMIN:
+ return SyftError(
+ message="You can't request password reset for an Admin user."
+ )
+
+ user.reset_token = self.generate_new_password_reset_token(
+ context.server.settings.pwd_token_config
+ )
+ user.reset_token_date = datetime.now()
+
+ result = self.stash.update(
+ credentials=context.credentials, user=user, has_permission=True
+ )
+ if result.is_err():
+ return SyftError(
+ message=(
+ f"Failed to update user with UID: {uid}. Error: {str(result.err())}"
+ )
+ )
+
+ return user.reset_token
+
+ @service_method(
+ path="user.reset_password", name="reset_password", roles=GUEST_ROLE_LEVEL
+ )
+ def reset_password(
+ self, context: AuthedServiceContext, token: str, new_password: str
+ ) -> SyftSuccess | SyftError:
+ """Resets a certain user password using a temporary token."""
+ result = self.stash.get_by_reset_token(
+ credentials=context.credentials, token=token
+ )
+ invalid_token_error = SyftError(
+ message=("Failed to reset user password. Token is invalid or expired!")
+ )
+
+ if result.is_err():
+ return SyftError(message="Failed to reset user password.")
+
+ user = result.ok()
+
+ # If token isn't found
+ if user is None:
+ return invalid_token_error
+
+ now = datetime.now()
+ time_difference = now - user.reset_token_date
+
+ # If token expired
+ expiration_time = context.server.settings.pwd_token_config.token_exp_min
+ if time_difference > timedelta(minutes=expiration_time):
+ return invalid_token_error
+
+ if not validate_password(new_password):
+ return SyftError(
+ message="Your new password must have at least 8 \
+ characters, Upper case and lower case characters\
+ and at least one number."
+ )
+
+ salt, hashed = salt_and_hash_password(new_password, 12)
+ user.hashed_password = hashed
+ user.salt = salt
+
+ user.reset_token = None
+ user.reset_token_date = None
+
+ result = self.stash.update(
+ credentials=context.credentials, user=user, has_permission=True
+ )
+ if result.is_err():
+ return SyftError(
+ message=(f"Failed to update user password. Error: {str(result.err())}")
+ )
+ return SyftSuccess(message="User Password updated successfully!")
+
+ def generate_new_password_reset_token(
+ self, token_config: PwdTokenResetConfig
+ ) -> str:
+ valid_characters = ""
+ if token_config.ascii:
+ valid_characters += string.ascii_letters
+
+ if token_config.numbers:
+ valid_characters += string.digits
+
+ generated_token = "".join(
+ secrets.choice(valid_characters) for _ in range(token_config.token_len)
+ )
+
+ return generated_token
+
@service_method(path="user.view", name="view", roles=DATA_SCIENTIST_ROLE_LEVEL)
def view(
self, context: AuthedServiceContext, uid: UID
diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py
index 2b2b42db9e8..894d9a65115 100644
--- a/packages/syft/src/syft/service/user/user_stash.py
+++ b/packages/syft/src/syft/service/user/user_stash.py
@@ -23,6 +23,7 @@
# 🟡 TODO 27: it would be nice if these could be defined closer to the User
EmailPartitionKey = PartitionKey(key="email", type_=str)
+PasswordResetTokenPartitionKey = PartitionKey(key="reset_token", type_=str)
RolePartitionKey = PartitionKey(key="role", type_=ServiceRole)
SigningKeyPartitionKey = PartitionKey(key="signing_key", type_=SyftSigningKey)
VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey)
@@ -74,6 +75,12 @@ def get_by_uid(
qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)])
return self.query_one(credentials=credentials, qks=qks)
+ def get_by_reset_token(
+ self, credentials: SyftVerifyKey, token: str
+ ) -> Result[User | None, str]:
+ qks = QueryKeys(qks=[PasswordResetTokenPartitionKey.with_obj(token)])
+ return self.query_one(credentials=credentials, qks=qks)
+
def get_by_email(
self, credentials: SyftVerifyKey, email: str
) -> Result[User | None, str]:
diff --git a/packages/syft/src/syft/types/syft_object_registry.py b/packages/syft/src/syft/types/syft_object_registry.py
index d5cc342635e..3d0548f6cf1 100644
--- a/packages/syft/src/syft/types/syft_object_registry.py
+++ b/packages/syft/src/syft/types/syft_object_registry.py
@@ -131,7 +131,10 @@ def get_transform(
klass_from = type_from_mro.__name__
version_from = None
for type_to_mro in type_to.mro():
- if issubclass(type_to_mro, SyftBaseObject):
+ if (
+ issubclass(type_to_mro, SyftBaseObject)
+ and type_to_mro != SyftBaseObject
+ ):
klass_to = type_to_mro.__canonical_name__
version_to = type_to_mro.__version__
else:
diff --git a/tox.ini b/tox.ini
index 5fd5cf7fc6f..4e5ae2d2001 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1078,7 +1078,7 @@ commands =
description = Prepare Migration Data
pip_pre = True
deps =
- syft==0.8.7
+ syft==0.8.8
nbmake
allowlist_externals =
bash