Skip to content

Commit

Permalink
Merge pull request #555 from graphistry/fix/refresh_sso
Browse files Browse the repository at this point in the history
Attempt to fix refresh() for SSO
  • Loading branch information
aucahuasi authored Apr 6, 2024
2 parents d140bb3 + 56aaf31 commit 4cc316b
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 13 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm

## [Development]

## [0.33.7 - 2024-04-06]

* Fix refresh() for SSO

## [0.33.6 - 2024-04-05]

### Added
Expand Down
10 changes: 8 additions & 2 deletions graphistry/arrow_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from graphistry.privacy import Mode, Privacy

from .ArrowFileUploader import ArrowFileUploader

from .exceptions import TokenExpireException
from .validate.validate_encodings import validate_encodings
from .util import setup_logger
logger = setup_logger(__name__)
Expand Down Expand Up @@ -362,10 +364,14 @@ def refresh(self, token=None):
try:
json_response = out.json()
if not ('token' in json_response):
if (
"non_field_errors" in json_response and "Token has expired." in json_response["non_field_errors"]
):
raise TokenExpireException(out.text)
raise Exception(out.text)
except Exception:
except Exception as e:
logger.error('Error: %s', out, exc_info=True)
raise Exception(out.text)
raise e

self.token = out.json()['token']
return self
Expand Down
7 changes: 7 additions & 0 deletions graphistry/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ class SsoRetrieveTokenTimeoutException(SsoException):
Koa, 15 Sep 2022 Custom Exception to Sso retrieve token time out exception scenario
"""
pass


class TokenExpireException(Exception):
"""
Koa, 15 Mar 2024 Custom Exception for JWT Token expiry when refresh
"""
pass
40 changes: 29 additions & 11 deletions graphistry/pygraphistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from . import bolt_util
from .plotter import Plotter
from .util import in_databricks, setup_logger, in_ipython, make_iframe
from .exceptions import SsoRetrieveTokenTimeoutException
from .exceptions import SsoRetrieveTokenTimeoutException, TokenExpireException

from .messages import (
MSG_REGISTER_MISSING_PASSWORD,
Expand Down Expand Up @@ -219,7 +219,6 @@ def sso_login(org_name=None, idp_name=None, sso_timeout=SSO_GET_TOKEN_ELAPSE_SEC
SSO Login logic.
"""

if PyGraphistry._config['store_token_creds_in_memory']:
PyGraphistry.relogin = lambda: PyGraphistry.sso_login(
org_name, idp_name, sso_timeout, sso_opt_into_type
Expand All @@ -232,22 +231,24 @@ def sso_login(org_name=None, idp_name=None, sso_timeout=SSO_GET_TOKEN_ELAPSE_SEC
+ PyGraphistry.server(), # noqa: W503
certificate_validation=PyGraphistry.certificate_validation(),
).sso_login(org_name, idp_name)

try:
# print(f"@sso_login - arrow_uploader.token: {arrow_uploader.token}")
if arrow_uploader.token:
PyGraphistry.api_token(arrow_uploader.token)
PyGraphistry._is_authenticated = True
arrow_uploader.token = None
return PyGraphistry.api_token()
except Exception: # required to log on
# print("required to log on")
except (Exception, TokenExpireException) as e: # required to log on

logger.debug(f"@sso_login - arrow_uploader.sso_state: {arrow_uploader.sso_state}")
PyGraphistry.sso_state(arrow_uploader.sso_state)

auth_url = arrow_uploader.sso_auth_url
# print("auth_url : {}".format(auth_url))
if auth_url and not PyGraphistry.api_token():

if auth_url:
PyGraphistry._handle_auth_url(auth_url, sso_timeout, sso_opt_into_type)
return auth_url
raise e

@staticmethod
def _handle_auth_url(auth_url, sso_timeout, sso_opt_into_type):
Expand All @@ -266,7 +267,6 @@ def _handle_auth_url(auth_url, sso_timeout, sso_opt_into_type):
SSO Login logic.
"""

if in_ipython() or in_databricks() or sso_opt_into_type == 'display': # If run in notebook, just display the HTML
# from IPython.core.display import HTML
from IPython.display import display, HTML
Expand Down Expand Up @@ -296,8 +296,10 @@ def _handle_auth_url(auth_url, sso_timeout, sso_opt_into_type):
try:
if not token:
if elapsed_time % 10 == 1:
print("Waiting for token : {} seconds ...".format(sso_timeout - elapsed_time + 1))

count_down = "Waiting for token : {} seconds ...".format(sso_timeout - elapsed_time + 1)
print(count_down)
from IPython.display import display, HTML
display(HTML(f'<strong>{count_down}</string>'))
time.sleep(1)
elapsed_time = elapsed_time + 1
if elapsed_time > sso_timeout:
Expand All @@ -316,10 +318,23 @@ def _handle_auth_url(auth_url, sso_timeout, sso_opt_into_type):
print("Successfully logged in")
return PyGraphistry.api_token()
else:
print("Please run graphistry.sso_get_token() to complete the authentication after you have authenticated via SSO")
return None
else:
print("Please run graphistry.sso_get_token() to complete the authentication")
# print("Start getting token ...")
# token = None
# for i in range(10):
# token, org_name = PyGraphistry._sso_get_token()
# if token:
# # set org_name to sso org
# PyGraphistry._config['org_name'] = org_name
# print("Successfully logged in")
# return PyGraphistry.api_token()
# print("Keep trying to get token ...")
# time.sleep(5)

print("Please run graphistry.sso_get_token() to complete the authentication")
return None

@staticmethod
def sso_get_token():
Expand Down Expand Up @@ -384,9 +399,12 @@ def refresh(token=None, fail_silent=False):
PyGraphistry._is_authenticated = True
return PyGraphistry.api_token()
except Exception as e:

if PyGraphistry.store_token_creds_in_memory():
logger.debug("JWT refresh via creds")
logger.debug("2. @PyGraphistry refresh :relogin")
if isinstance(e, TokenExpireException):
print("Token is expired, you need to relogin")
return PyGraphistry.relogin()

if not fail_silent:
Expand Down

0 comments on commit 4cc316b

Please sign in to comment.