Skip to content

Commit

Permalink
Chore: Make release 1.2.5
Browse files Browse the repository at this point in the history
  • Loading branch information
martinroberson committed Nov 20, 2024
1 parent b367f39 commit 20b500d
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 72 deletions.
36 changes: 34 additions & 2 deletions gs_quant/api/gs/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,38 @@ def wrapper(cls, *args, **kwargs):
return wrapper


def _cached_async(fn):
    """Decorator for async classmethods that caches results per session + arguments.

    Async counterpart of ``_cached``. Caching is active only while the
    ``ENABLE_ASSET_CACHING`` environment variable is set; otherwise the wrapped
    coroutine is awaited directly with no cache interaction. Cache reads and
    writes are serialized with a per-decorator lock.
    """
    _fn_cache_lock = threading.Lock()
    # short-term cache to avoid retrieving the same data several times in succession
    fallback_cache: AssetCache = get_default_cache()

    @wraps(fn)
    async def wrapper(cls, *args, **kwargs):
        if os.environ.get(ENABLE_ASSET_CACHING):
            _logger.info("Asset caching is enabled")
            asset_cache = cls.get_cache() or fallback_cache
            k = asset_cache.construct_key(GsSession.current, fn.__name__, *args, **kwargs)
            with Tracer("acquiring cache lock"):
                _logger.debug('cache get: %s', k)
                # NOTE(review): threading.Lock blocks the event loop while held;
                # acceptable for these short critical sections, but worth confirming.
                with _fn_cache_lock:
                    result = asset_cache.cache.get(GsSession.current, k)
                    # `is not None` (was `if result:`) so a cached empty result
                    # still counts as a hit instead of being re-fetched every call.
                    if result is not None:
                        _logger.debug('cache hit: %s', k)
                        return result
            with Tracer("Executing function"):
                result = await fn(cls, *args, **kwargs)
            with Tracer("acquiring cache lock"):
                _logger.debug('cache set: %s', k)
                with _fn_cache_lock:
                    asset_cache.cache.put(GsSession.current, k, result, ttl=asset_cache.ttl)
        else:
            _logger.info("Asset caching is disabled, calling function")
            result = await fn(cls, *args, **kwargs)
        return result

    return wrapper


class GsIdType(Enum):
"""GS Asset API identifier type enumeration"""

Expand Down Expand Up @@ -199,7 +231,7 @@ def get_many_assets(
return response['results']

@classmethod
@_cached
@_cached_async
async def get_many_assets_async(
cls,
fields: IdList = None,
Expand Down Expand Up @@ -322,7 +354,7 @@ def get_asset(
return GsSession.current._get('/assets/{id}'.format(id=asset_id), cls=GsAsset)

@classmethod
@_cached
@_cached_async
async def get_asset_async(
cls,
asset_id: str,
Expand Down
25 changes: 17 additions & 8 deletions gs_quant/api/gs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,37 +190,39 @@ def _check_cache(cls, url, **kwargs):
return cached_val, cache_key, session

@classmethod
def _post_with_cache_check(cls, url, validator=lambda x: x, domain=None, **kwargs):
    """POST ``url``, consulting the API request cache first.

    ``validator`` is applied to a fresh response *before* it is cached, so a
    validator that raises (e.g. on an error payload) prevents bad responses
    from being stored. Cached hits are returned as-is, without re-validation.
    Returns the (possibly cached) response body.
    """
    # _check_cache yields (cached_value_or_None, cache_key, session)
    result, cache_key, session = cls._check_cache(url=url, **kwargs)
    if result is None:
        # Validate before caching: if validator raises, nothing is stored.
        result = validator(session._post(url, domain=domain, **kwargs))
        if cls._api_request_cache:
            cls._api_request_cache.put(session, cache_key, result)
    return result

@classmethod
def _get_with_cache_check(cls, url, validator=lambda x: x, domain=None, **kwargs):
    """GET ``url``, consulting the API request cache first.

    ``validator`` runs on fresh responses only, before they are cached;
    cache hits are returned without re-validation. Returns the (possibly
    cached) response body.
    """
    # _check_cache yields (cached_value_or_None, cache_key, session)
    result, cache_key, session = cls._check_cache(url, **kwargs)
    if result is None:
        # Validate before caching: if validator raises, nothing is stored.
        result = validator(session._get(url, domain=domain, **kwargs))
        if cls._api_request_cache:
            cls._api_request_cache.put(session, cache_key, result)
    return result

@classmethod
async def _get_with_cache_check_async(cls, url, validator=lambda x: x, domain=None, **kwargs):
    """Async GET ``url``, consulting the API request cache first.

    ``validator`` runs on fresh responses only, before they are cached;
    cache hits are returned without re-validation. Returns the (possibly
    cached) response body.
    """
    result, cache_key, session = cls._check_cache(url, **kwargs)
    if result is None:
        result = await session._get_async(url, domain=domain, **kwargs)
        # Validate before caching: if validator raises, nothing is stored.
        result = validator(result)
        if cls._api_request_cache:
            cls._api_request_cache.put(session, cache_key, result)
    return result

@classmethod
async def _post_with_cache_check_async(cls, url, validator=lambda x: x, domain=None, **kwargs):
    """Async POST ``url``, consulting the API request cache first.

    ``validator`` runs on fresh responses only, before they are cached;
    cache hits are returned without re-validation. Returns the (possibly
    cached) response body.
    """
    result, cache_key, session = cls._check_cache(url, **kwargs)
    if result is None:
        result = await session._post_async(url, domain=domain, **kwargs)
        # Validate before caching: if validator raises, nothing is stored.
        result = validator(result)
        if cls._api_request_cache:
            cls._api_request_cache.put(session, cache_key, result)
    return result
Expand Down Expand Up @@ -945,9 +947,16 @@ def get_data_providers(cls,

@classmethod
def get_market_data(cls, query, request_id=None, ignore_errors: bool = False) -> pd.DataFrame:
def validate(body):
for e in body['responses']:
container = e['queryResponse'][0]
if 'errorMessages' in container:
msg = f'measure service request {body["requestId"]} failed: {container["errorMessages"]}'
raise MqValueError(msg)
return body
start = time.perf_counter()
try:
body = cls._post_with_cache_check('/data/measures', payload=query)
body = cls._post_with_cache_check(url='/data/measures', validator=validate, payload=query)
except Exception as e:
log_warning(request_id, _logger, f'Market data query {query} failed due to {e}')
raise e
Expand Down
171 changes: 112 additions & 59 deletions gs_quant/test/data/test_data_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
specific language governing permissions and limitations
under the License.
"""

import datetime as dt
from unittest.mock import patch

Expand All @@ -23,15 +24,13 @@
from gs_quant.data import Dataset, DataContext


class NotExpectedToBeCalledSession():
class NotExpectedToBeCalledSession:
redirect_to_mds = True

@classmethod
def _get(cls, url, **kwargs):
if url == '/data/datasets/FXSPOT_STANDARD':
return {
"id": "FXSPOT_STANDARD"
}
if url == "/data/datasets/FXSPOT_STANDARD":
return {"id": "FXSPOT_STANDARD"}
else:
raise Exception("Not expecting to be called at this point")

Expand All @@ -40,71 +39,106 @@ def _post(cls, url, **kwargs):
raise Exception("Not expecting to be called at this point")


class FakeSession():
class MarketDataErrorSession:
    """Fake session whose every POST returns a measure-service error payload.

    Used to verify that failed market-data responses are surfaced as errors
    and never written to the cache.
    """

    redirect_to_mds = True

    @classmethod
    def _post(cls, url, **kwargs):
        # Always report a single query response carrying an error message.
        failing_response = {"queryResponse": [{"errorMessages": ["Test Failure"]}]}
        return {"requestId": "890", "responses": [failing_response]}


class FakeSession:
redirect_to_mds = True

@classmethod
def _get(cls, url, **kwargs):
if url == '/data/catalog/FXSPOT_STANDARD':
return {
'id': 'FXSPOT_STANDARD',
'fields': {
'date': {'type': 'string', 'format': 'date'},
'assetId': {'type': 'string'},
'spot': {'type': 'number', },
'updateTime': {'type': 'string', 'format': 'date-time'}
}
}
elif url == '/data/datasets/FXSPOT_STANDARD':
if url == "/data/catalog/FXSPOT_STANDARD":
return {
"id": "FXSPOT_STANDARD"
"id": "FXSPOT_STANDARD",
"fields": {
"date": {"type": "string", "format": "date"},
"assetId": {"type": "string"},
"spot": {
"type": "number",
},
"updateTime": {"type": "string", "format": "date-time"},
},
}
elif url == "/data/datasets/FXSPOT_STANDARD":
return {"id": "FXSPOT_STANDARD"}
else:
raise Exception("Need to mock _get request here")

@classmethod
def _post(cls, url, **kwargs):
if url == '/data/FXSPOT_STANDARD/last/query':
if url == "/data/FXSPOT_STANDARD/last/query":
return {
'requestId': '1234',
'data': [
"requestId": "1234",
"data": [
{
'date': '2023-10-25',
'assetId': 'MATGYV0J9MPX534Z',
'bbid': 'USDJPY',
'spot': 150.123,
'updateTime': '2023-10-25T21:53:56Z'
"date": "2023-10-25",
"assetId": "MATGYV0J9MPX534Z",
"bbid": "USDJPY",
"spot": 150.123,
"updateTime": "2023-10-25T21:53:56Z",
}
]}
elif url == '/data/FXSPOT_STANDARD/query':
],
}
elif url == "/data/FXSPOT_STANDARD/query":
return {
'requestId': '5678',
'data': [
"requestId": "5678",
"data": [
{
'date': '2023-10-26',
'assetId': 'MATGYV0J9MPX534Z',
'bbid': 'USDJPY',
'spot': 152.234,
'updateTime': '2023-10-26T21:53:56Z'
"date": "2023-10-26",
"assetId": "MATGYV0J9MPX534Z",
"bbid": "USDJPY",
"spot": 152.234,
"updateTime": "2023-10-26T21:53:56Z",
}
]
],
}
elif url == '/data/measures':
elif url == "/data/measures":
return {
'requestId': '890', 'responses': [
"requestId": "890",
"responses": [
{
'queryResponse': [
{'measure': 'Curve', 'dataSetIds': ['DATASET_FOO'], 'entityTypes': ['ASSET'],
'response': {
'data': [
{'date': '2023-04-11', 'assetId': 'MATGYV0J9MPX534Z', 'pricingLocation': 'HKG',
'name': 'USDJPY', 'spot': 133.},
{'date': '2023-04-11', 'assetId': 'MATGYV0J9MPX534Z', 'pricingLocation': 'LDN',
'name': 'USDJPY', 'spot': 134.0},
{'date': '2023-04-11', 'assetId': 'MATGYV0J9MPX534Z', 'pricingLocation': 'NYC',
'name': 'USDJPY', 'spot': 136.0}]}}]
"queryResponse": [
{
"measure": "Curve",
"dataSetIds": ["DATASET_FOO"],
"entityTypes": ["ASSET"],
"response": {
"data": [
{
"date": "2023-04-11",
"assetId": "MATGYV0J9MPX534Z",
"pricingLocation": "HKG",
"name": "USDJPY",
"spot": 133.0,
},
{
"date": "2023-04-11",
"assetId": "MATGYV0J9MPX534Z",
"pricingLocation": "LDN",
"name": "USDJPY",
"spot": 134.0,
},
{
"date": "2023-04-11",
"assetId": "MATGYV0J9MPX534Z",
"pricingLocation": "NYC",
"name": "USDJPY",
"spot": 136.0,
},
]
},
}
]
}
]
],
}


Expand All @@ -119,10 +153,12 @@ def teardown_method(self, test_method):

def test_last_data(self):
ds = Dataset("FXSPOT_STANDARD")
with patch.object(GsDataApi, 'get_session', return_value=FakeSession()):
df = ds.get_data_last(as_of=dt.date(2023, 10, 25), bbid='USDJPY')
with patch.object(GsDataApi, 'get_session', return_value=NotExpectedToBeCalledSession()):
df2 = ds.get_data_last(dt.date(2023, 10, 25), bbid='USDJPY')
with patch.object(GsDataApi, "get_session", return_value=FakeSession()):
df = ds.get_data_last(as_of=dt.date(2023, 10, 25), bbid="USDJPY")
with patch.object(
GsDataApi, "get_session", return_value=NotExpectedToBeCalledSession()
):
df2 = ds.get_data_last(dt.date(2023, 10, 25), bbid="USDJPY")
assert not df.empty
assert_frame_equal(df, df2)
cache_events = self.cache.get_events()
Expand All @@ -134,10 +170,16 @@ def test_last_data(self):

def test_query_data(self):
ds = Dataset("FXSPOT_STANDARD")
with patch.object(GsDataApi, 'get_session', return_value=FakeSession()):
df = ds.get_data(dt.date(2023, 10, 26), dt.date(2023, 10, 26), bbid='USDJPY')
with patch.object(GsDataApi, 'get_session', return_value=NotExpectedToBeCalledSession()):
df2 = ds.get_data(dt.date(2023, 10, 26), dt.date(2023, 10, 26), bbid='USDJPY')
with patch.object(GsDataApi, "get_session", return_value=FakeSession()):
df = ds.get_data(
dt.date(2023, 10, 26), dt.date(2023, 10, 26), bbid="USDJPY"
)
with patch.object(
GsDataApi, "get_session", return_value=NotExpectedToBeCalledSession()
):
df2 = ds.get_data(
dt.date(2023, 10, 26), dt.date(2023, 10, 26), bbid="USDJPY"
)

assert_frame_equal(df, df2)
cache_events = self.cache.get_events()
Expand All @@ -152,9 +194,20 @@ def test_market_data(self):
with DataContext(dt.date(2023, 4, 11), dt.date(2023, 4, 11)):
q = GsDataApi.build_market_data_query([asset_id], QueryType.SPOT)

with patch.object(GsDataApi, 'get_session', return_value=FakeSession()):
with patch.object(
GsDataApi, "get_session", return_value=MarketDataErrorSession()
):
try:
df = GsDataApi.get_market_data(q)
except Exception:
pass
cache_events = self.cache.get_events()
assert len(cache_events) == 0
with patch.object(GsDataApi, "get_session", return_value=FakeSession()):
df = GsDataApi.get_market_data(q)
with patch.object(GsDataApi, 'get_session', return_value=NotExpectedToBeCalledSession()):
with patch.object(
GsDataApi, "get_session", return_value=NotExpectedToBeCalledSession()
):
df2 = GsDataApi.get_market_data(q)

assert_frame_equal(df, df2)
Expand Down
8 changes: 7 additions & 1 deletion gs_quant/timeseries/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
USE_DISPLAY_NAME = os.environ.get(ENABLE_DISPLAY_NAME) == "1"
_logger = logging.getLogger(__name__)


class Entitlement(Enum):
    """Entitlement flags attachable to a plot measure via ``plot_measure``'s ``entitlements`` parameter."""
    INTERNAL = 'internal'


try:
from quant_extensions.timeseries.rolling import rolling_apply
except ImportError as e:
Expand Down Expand Up @@ -203,7 +208,7 @@ def check_forward_looking(pricing_date, source, name="function"):

def plot_measure(asset_class: tuple, asset_type: Optional[tuple] = None,
dependencies: Optional[List[QueryType]] = tuple(), asset_type_excluded: Optional[tuple] = None,
display_name: Optional[str] = None):
display_name: Optional[str] = None, entitlements: Optional[List[Entitlement]] = []):
# Indicates that fn should be exported to plottool as a member function / pseudo-measure.
# Set category to None for no restrictions, else provide a tuple of allowed values.
def decorator(fn):
Expand All @@ -219,6 +224,7 @@ def decorator(fn):
fn.asset_type = asset_type
fn.asset_type_excluded = asset_type_excluded
fn.dependencies = dependencies
fn.entitlements = entitlements

if USE_DISPLAY_NAME:
fn.display_name = display_name
Expand Down
Loading

0 comments on commit 20b500d

Please sign in to comment.