Skip to content

Commit

Permalink
Fix #1636 Custom Values passed into correctly into Bot/Installation c…
Browse files Browse the repository at this point in the history
…lass when cloned during token rotation (#1638)
  • Loading branch information
seratch authored Jan 24, 2025
1 parent c9fff5d commit fcd0a35
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 12 deletions.
11 changes: 8 additions & 3 deletions slack_sdk/oauth/installation_store/models/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def set_custom_value(self, name: str, value: Any):
def get_custom_value(self, name: str) -> Optional[Any]:
return self.custom_values.get(name)

def to_dict(self) -> Dict[str, Any]:
standard_values = {
def _to_standard_value_dict(self) -> Dict[str, Any]:
return {
"app_id": self.app_id,
"enterprise_id": self.enterprise_id,
"enterprise_name": self.enterprise_name,
Expand All @@ -105,6 +105,11 @@ def to_dict(self) -> Dict[str, Any]:
"is_enterprise_install": self.is_enterprise_install,
"installed_at": datetime.utcfromtimestamp(self.installed_at),
}

def to_dict_for_copying(self) -> Dict[str, Any]:
return {"custom_values": self.custom_values, **self._to_standard_value_dict()}

def to_dict(self) -> Dict[str, Any]:
# prioritize standard_values over custom_values
# when the same keys exist in both
return {**self.custom_values, **standard_values}
return {**self.custom_values, **self._to_standard_value_dict()}
11 changes: 8 additions & 3 deletions slack_sdk/oauth/installation_store/models/installation.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def set_custom_value(self, name: str, value: Any):
def get_custom_value(self, name: str) -> Optional[Any]:
return self.custom_values.get(name)

def to_dict(self) -> Dict[str, Any]:
standard_values = {
def _to_standard_value_dict(self) -> Dict[str, Any]:
return {
"app_id": self.app_id,
"enterprise_id": self.enterprise_id,
"enterprise_name": self.enterprise_name,
Expand Down Expand Up @@ -190,6 +190,11 @@ def to_dict(self) -> Dict[str, Any]:
"token_type": self.token_type,
"installed_at": datetime.utcfromtimestamp(self.installed_at),
}

def to_dict_for_copying(self) -> Dict[str, Any]:
return {"custom_values": self.custom_values, **self._to_standard_value_dict()}

def to_dict(self) -> Dict[str, Any]:
# prioritize standard_values over custom_values
# when the same keys exist in both
return {**self.custom_values, **standard_values}
return {**self.custom_values, **self._to_standard_value_dict()}
6 changes: 3 additions & 3 deletions slack_sdk/oauth/token_rotation/async_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def perform_token_rotation(

if rotated_bot is not None:
if rotated_installation is None:
rotated_installation = Installation(**installation.to_dict())
rotated_installation = Installation(**installation.to_dict_for_copying())
rotated_installation.bot_token = rotated_bot.bot_token
rotated_installation.bot_refresh_token = rotated_bot.bot_refresh_token
rotated_installation.bot_token_expires_at = rotated_bot.bot_token_expires_at
Expand Down Expand Up @@ -93,7 +93,7 @@ async def perform_bot_token_rotation(
if refresh_response.get("token_type") != "bot":
return None

refreshed_bot = Bot(**bot.to_dict())
refreshed_bot = Bot(**bot.to_dict_for_copying())
refreshed_bot.bot_token = refresh_response["access_token"]
refreshed_bot.bot_refresh_token = refresh_response.get("refresh_token")
refreshed_bot.bot_token_expires_at = int(time()) + int(refresh_response["expires_in"])
Expand Down Expand Up @@ -132,7 +132,7 @@ async def perform_user_token_rotation(
if refresh_response.get("token_type") != "user":
return None

refreshed_installation = Installation(**installation.to_dict())
refreshed_installation = Installation(**installation.to_dict_for_copying())
refreshed_installation.user_token = refresh_response.get("access_token")
refreshed_installation.user_refresh_token = refresh_response.get("refresh_token")
refreshed_installation.user_token_expires_at = int(time()) + int(refresh_response.get("expires_in")) # type: ignore[arg-type] # noqa: E501
Expand Down
6 changes: 3 additions & 3 deletions slack_sdk/oauth/token_rotation/rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def perform_token_rotation(

if rotated_bot is not None:
if rotated_installation is None:
rotated_installation = Installation(**installation.to_dict())
rotated_installation = Installation(**installation.to_dict_for_copying())
rotated_installation.bot_token = rotated_bot.bot_token
rotated_installation.bot_refresh_token = rotated_bot.bot_refresh_token
rotated_installation.bot_token_expires_at = rotated_bot.bot_token_expires_at
Expand Down Expand Up @@ -85,7 +85,7 @@ def perform_bot_token_rotation(
if refresh_response.get("token_type") != "bot":
return None

refreshed_bot = Bot(**bot.to_dict())
refreshed_bot = Bot(**bot.to_dict_for_copying())
refreshed_bot.bot_token = refresh_response["access_token"]
refreshed_bot.bot_refresh_token = refresh_response.get("refresh_token")
refreshed_bot.bot_token_expires_at = int(time()) + int(refresh_response["expires_in"])
Expand Down Expand Up @@ -125,7 +125,7 @@ def perform_user_token_rotation(
if refresh_response.get("token_type") != "user":
return None

refreshed_installation = Installation(**installation.to_dict())
refreshed_installation = Installation(**installation.to_dict_for_copying())
refreshed_installation.user_token = refresh_response.get("access_token")
refreshed_installation.user_refresh_token = refresh_response.get("refresh_token")
refreshed_installation.user_token_expires_at = int(time()) + int(refresh_response["expires_in"])
Expand Down
4 changes: 4 additions & 0 deletions tests/slack_sdk/oauth/installation_store/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_bot(self):
)
self.assertIsNotNone(bot)
self.assertIsNotNone(bot.to_dict())
self.assertIsNotNone(bot.to_dict_for_copying())

def test_bot_custom_fields(self):
bot = Bot(
Expand All @@ -33,6 +34,7 @@ def test_bot_custom_fields(self):
bot.set_custom_value("app_id", "A222")
self.assertEqual(bot.get_custom_value("service_user_id"), "XYZ123")
self.assertEqual(bot.to_dict().get("service_user_id"), "XYZ123")
self.assertEqual(bot.to_dict_for_copying().get("custom_values").get("service_user_id"), "XYZ123")

def test_installation(self):
installation = Installation(
Expand Down Expand Up @@ -73,10 +75,12 @@ def test_installation_custom_fields(self):
self.assertEqual(installation.get_custom_value("service_user_id"), "XYZ123")
self.assertEqual(installation.to_dict().get("service_user_id"), "XYZ123")
self.assertEqual(installation.to_dict().get("app_id"), "A111")
self.assertEqual(installation.to_dict_for_copying().get("custom_values").get("app_id"), "A222")

bot = installation.to_bot()
self.assertEqual(bot.app_id, "A111")
self.assertEqual(bot.get_custom_value("service_user_id"), "XYZ123")

self.assertEqual(bot.to_dict().get("app_id"), "A111")
self.assertEqual(bot.to_dict().get("service_user_id"), "XYZ123")
self.assertEqual(bot.to_dict_for_copying().get("custom_values").get("app_id"), "A222")
25 changes: 25 additions & 0 deletions tests/slack_sdk/oauth/token_rotation/test_token_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,31 @@ def test_refresh(self):
)
self.assertIsNone(should_not_be_refreshed)

def test_refresh_with_custom_values(self):
installation = Installation(
app_id="A111",
enterprise_id="E111",
team_id="T111",
user_id="U111",
bot_id="B111",
bot_token="xoxe.xoxp-1-initial",
bot_scopes=["chat:write"],
bot_user_id="U222",
bot_refresh_token="xoxe-1-initial",
bot_token_expires_in=43200,
custom_values={"foo": "bar"},
)
refreshed = self.token_rotator.perform_token_rotation(
installation=installation, minutes_before_expiration=60 * 24 * 365
)
self.assertIsNotNone(refreshed)
self.assertIsNotNone(refreshed.custom_values)

should_not_be_refreshed = self.token_rotator.perform_token_rotation(
installation=installation, minutes_before_expiration=1
)
self.assertIsNone(should_not_be_refreshed)

def test_token_rotation_disabled(self):
installation = Installation(
app_id="A111",
Expand Down
26 changes: 26 additions & 0 deletions tests/slack_sdk_async/oauth/token_rotation/test_token_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,32 @@ async def test_refresh(self):
)
self.assertIsNone(should_not_be_refreshed)

@async_test
async def test_refresh_with_custom_values(self):
installation = Installation(
app_id="A111",
enterprise_id="E111",
team_id="T111",
user_id="U111",
bot_id="B111",
bot_token="xoxe.xoxp-1-initial",
bot_scopes=["chat:write"],
bot_user_id="U222",
bot_refresh_token="xoxe-1-initial",
bot_token_expires_in=43200,
custom_values={"foo": "bar"},
)
refreshed = await self.token_rotator.perform_token_rotation(
installation=installation, minutes_before_expiration=60 * 24 * 365
)
self.assertIsNotNone(refreshed)
self.assertIsNotNone(refreshed.custom_values)

should_not_be_refreshed = await self.token_rotator.perform_token_rotation(
installation=installation, minutes_before_expiration=1
)
self.assertIsNone(should_not_be_refreshed)

@async_test
async def test_token_rotation_disabled(self):
installation = Installation(
Expand Down

0 comments on commit fcd0a35

Please sign in to comment.