Skip to content

Commit

Permalink
Merge pull request #8774 from OpenMined/eelco/cleanup-table
Browse files Browse the repository at this point in the history
refactor table generation
  • Loading branch information
tcp authored May 3, 2024
2 parents ee441d3 + e0ce3df commit feb7a59
Show file tree
Hide file tree
Showing 13 changed files with 11,280 additions and 212 deletions.
10,887 changes: 10,887 additions & 0 deletions notebooks/notebook_ui/table_examples.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions packages/syft/src/syft/assets/css/style.css
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
.syft-widget body[data-jp-theme-light="false"] {
body[data-jp-theme-light="false"] {
--primary-color: #111111;
--secondary-color: #212121;
--tertiary-color: #cfcdd6;
--button-color: #111111;
}

.syft-widget body {
body {
--primary-color: #ffffff;
--secondary-color: #f5f5f5;
--tertiary-color: #000000de;
Expand Down
11 changes: 6 additions & 5 deletions packages/syft/src/syft/service/code_history/code_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.syft_object import SyftVerifyKey
from ...types.syft_object import get_repr_values_table
from ...types.uid import UID
from ...util.notebook_ui.components.table import create_table_template
from ...util.notebook_ui.components.table_template import create_table_template
from ...util.table import prepare_table_data
from ..code.user_code import UserCode
from ..response import SyftError

Expand Down Expand Up @@ -55,7 +55,8 @@ def _coll_repr_(self) -> dict[str, int]:
return {"Number of versions": len(self.user_code_history)}

def _repr_html_(self) -> str:
rows = get_repr_values_table(self.user_code_history, True)
# TODO techdebt: move this to _coll_repr_
rows, _ = prepare_table_data(self.user_code_history)
for i, r in enumerate(rows):
r["Version"] = f"v{i}"
raw_code = self.user_code_history[i].raw_code
Expand All @@ -64,7 +65,7 @@ def _repr_html_(self) -> str:
raw_code = "\n".join(raw_code.split("\n", 5))
r["Code"] = raw_code
# rows = sorted(rows, key=lambda x: x["Version"])
return create_table_template(rows, "CodeHistory", table_icon=None)
return create_table_template(rows, "CodeHistory", icon=None)

def __getitem__(self, index: int | str) -> UserCode | SyftError:
if isinstance(index, str):
Expand Down Expand Up @@ -139,4 +140,4 @@ def _repr_html_(self) -> str:
rows = []
for user, funcs in self.user_dict.items():
rows += [{"user": user, "UserCodes": funcs}]
return create_table_template(rows, "UserCodeHistory", table_icon=None)
return create_table_template(rows, "UserCodeHistory", icon=None)
3 changes: 2 additions & 1 deletion packages/syft/src/syft/service/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ class Dataset(SyftObject):
__attr_searchable__ = ["name", "citation", "url", "description", "action_ids"]
__attr_unique__ = ["name"]
__repr_attrs__ = ["name", "url", "created_at"]
__table_sort_attr__ = "Created at"

def __init__(
self,
Expand All @@ -476,7 +477,7 @@ def _coll_repr_(self) -> dict[str, Any]:
"Assets": len(self.asset_list),
"Size": f"{self.mb_size} (MB)",
"Url": self.url,
"created at": str(self.created_at),
"Created at": str(self.created_at),
}

def _repr_html_(self) -> Any:
Expand Down
3 changes: 2 additions & 1 deletion packages/syft/src/syft/service/notification/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class Notification(SyftObject):
"status",
]
__repr_attrs__ = ["subject", "status", "created_at", "linked_obj"]
__table_sort_attr__ = "Created at"

def _repr_html_(self) -> str:
return f"""
Expand Down Expand Up @@ -101,7 +102,7 @@ def _coll_repr_(self) -> dict[str, str]:
return {
"Subject": self.subject,
"Status": self.determine_status().name.capitalize(),
"Created At": str(self.created_at),
"Created at": str(self.created_at),
"Linked object": f"{self.linked_obj.object_type.__canonical_name__} ({self.linked_obj.object_uid})",
}

Expand Down
3 changes: 3 additions & 0 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ class Request(SyncableSyftObject):
"auto",
"auto",
"auto",
"auto",
]

__attr_searchable__ = [
Expand All @@ -368,6 +369,7 @@ class Request(SyncableSyftObject):
"requesting_user_verify_key",
]
__exclude_sync_diff_attrs__ = ["node_uid"]
__table_sort_attr__ = "Request time"

def _repr_html_(self) -> Any:
# add changes
Expand Down Expand Up @@ -465,6 +467,7 @@ def _coll_repr_(self) -> dict[str, str | dict[str, str]]:
]

return {
"Request time": str(self.request_time),
"Description": self.html_description,
"Requested By": "\n".join(user_data),
"Status": status_badge,
Expand Down
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/worker/worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class WorkerPool(SyftObject):
"workers",
"created_at",
]
__table_sort_attr__ = "Created at"

name: str
image_id: UID | None = None
Expand Down
15 changes: 14 additions & 1 deletion packages/syft/src/syft/types/datetime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
from datetime import datetime
from functools import total_ordering
import re
from typing import Any

# third party
Expand All @@ -12,6 +13,13 @@
from .syft_object import SyftObject
from .uid import UID

DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
DATETIME_REGEX = r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}"


def str_is_datetime(str_: str) -> bool:
return bool(re.match(DATETIME_REGEX, str_))


@serializable()
@total_ordering
Expand All @@ -26,9 +34,14 @@ class DateTime(SyftObject):
def now(cls) -> Self:
return cls(utc_timestamp=datetime.utcnow().timestamp())

@classmethod
def from_str(cls, datetime_str: str) -> "DateTime":
dt = datetime.strptime(datetime_str, DATETIME_FORMAT)
return cls(utc_timestamp=dt.timestamp())

def __str__(self) -> str:
utc_datetime = datetime.utcfromtimestamp(self.utc_timestamp)
return utc_datetime.strftime("%Y-%m-%d %H:%M:%S")
return utc_datetime.strftime(DATETIME_FORMAT)

def __hash__(self) -> int:
return hash(self.utc_timestamp)
Expand Down
191 changes: 2 additions & 189 deletions packages/syft/src/syft/types/syft_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from hashlib import sha256
import inspect
from inspect import Signature
import re
import traceback
import types
from types import NoneType
from types import UnionType
Expand All @@ -27,7 +25,6 @@
from typing import get_origin

# third party
import pandas as pd
import pydantic
from pydantic import ConfigDict
from pydantic import EmailStr
Expand All @@ -44,7 +41,7 @@
from ..serde.serialize import _serialize as serialize
from ..util.autoreload import autoreload_enabled
from ..util.markdown import as_markdown_python_code
from ..util.notebook_ui.components.table import create_table_template
from ..util.table import list_dict_repr_html
from ..util.util import aggressive_set_attr
from ..util.util import full_name_with_qualname
from ..util.util import get_qualname_for
Expand Down Expand Up @@ -422,6 +419,7 @@ def make_id(cls, values: Any) -> Any:
)
__validate_private_attrs__: ClassVar[bool] = True
__table_coll_widths__: ClassVar[list[str] | None] = None
__table_sort_attr__: ClassVar[str | None] = None

def __syft_get_funcs__(self) -> list[tuple[str, Signature]]:
funcs = print_type_cache[type(self)]
Expand Down Expand Up @@ -764,191 +762,6 @@ def short_uid(uid: UID | None) -> str | None:
return str(uid)[:6] + "..."


def get_repr_values_table(
_self: Mapping | Iterable,
is_homogenous: bool,
extra_fields: list | None = None,
) -> dict:
if extra_fields is None:
extra_fields = []

cols = defaultdict(list)
for item in iter(_self.items() if isinstance(_self, Mapping) else _self):
# unpack dict
if isinstance(_self, Mapping):
key, item = item
cols["key"].append(key)

# get id
id_ = getattr(item, "id", None)
include_id = getattr(item, "__syft_include_id_coll_repr__", True)
if id_ is not None and include_id:
cols["id"].append({"value": str(id_), "type": "clipboard"})

if type(item) == type:
t = full_name_with_qualname(item)
else:
try:
t = item.__class__.__name__
except Exception:
t = item.__repr__()

if not is_homogenous:
cols["type"].append(t)

# if has _coll_repr_

if hasattr(item, "_coll_repr_"):
ret_val = item._coll_repr_()
if "id" in ret_val:
del ret_val["id"]
for key in ret_val.keys():
cols[key].append(ret_val[key])
else:
for field in extra_fields:
value = item
try:
attrs = field.split(".")
for i, attr in enumerate(attrs):
# find indexing like abc[1]
res = re.search(r"\[[+-]?\d+\]", attr)
has_index = False
if res:
has_index = True
index_str = res.group()
index = int(index_str.replace("[", "").replace("]", ""))
attr = attr.replace(index_str, "")

value = getattr(value, attr, None)
if isinstance(value, list) and has_index:
value = value[index]
# If the object has a special representation when nested we will use that instead
if (
hasattr(value, "__repr_syft_nested__")
and i == len(attrs) - 1
):
value = value.__repr_syft_nested__()
if (
isinstance(value, list)
and i == len(attrs) - 1
and len(value) > 0
and hasattr(value[0], "__repr_syft_nested__")
):
value = [
(
x.__repr_syft_nested__()
if hasattr(x, "__repr_syft_nested__")
else x
)
for x in value
]
if value is None:
value = "n/a"

except Exception as e:
print(e)
value = None
cols[field].append(str(value))

df = pd.DataFrame(cols)

if "created_at" in df.columns:
df.sort_values(by="created_at", ascending=False, inplace=True)

return df.to_dict("records") # type: ignore


def _get_grid_template_columns(first_value: Any) -> tuple[str | None, str | None]:
grid_template_cols = getattr(first_value, "__table_coll_widths__", None)
if isinstance(grid_template_cols, list):
grid_template_columns = " ".join(grid_template_cols)
grid_template_cell_columns = "unset"
else:
grid_template_columns = None
grid_template_cell_columns = None
return grid_template_columns, grid_template_cell_columns


def list_dict_repr_html(self: Mapping | Set | Iterable) -> str:
try:
max_check = 1
items_checked = 0
has_syft = False
extra_fields: list = []
if isinstance(self, Mapping):
values: Any = list(self.values())
elif isinstance(self, Set):
values = list(self)
else:
values = self

if len(values) == 0:
return self.__repr__()

for item in iter(self.values() if isinstance(self, Mapping) else self):
items_checked += 1
if items_checked > max_check:
break

if hasattr(type(item), "mro") and type(item) != type:
mro: list | str = type(item).mro()
elif hasattr(item, "mro") and type(item) != type:
mro = item.mro()
else:
mro = str(self)

if "syft" in str(mro).lower():
has_syft = True
extra_fields = getattr(item, "__repr_attrs__", [])
break

if has_syft:
# if custom_repr:
table_icon = None
if hasattr(values[0], "icon"):
table_icon = values[0].icon
# this is a list of dicts
is_homogenous = len({type(x) for x in values}) == 1
# third party

try:
vals = get_repr_values_table(
self, is_homogenous, extra_fields=extra_fields
)
except Exception:
return str(self)

first_value = values[0]
if is_homogenous:
cls_name = first_value.__class__.__name__
grid_template_columns, grid_template_cell_columns = (
_get_grid_template_columns(first_value)
)
else:
cls_name = ""
grid_template_columns = None
grid_template_cell_columns = None

return create_table_template(
vals,
f"{cls_name} {self.__class__.__name__.capitalize()}",
table_icon=table_icon,
grid_template_columns=grid_template_columns,
grid_template_cell_columns=grid_template_cell_columns,
)

except Exception as e:
print(
f"error representing {type(self)} of objects. {e}, {traceback.format_exc()}"
)
pass

# stdlib
import html

return html.escape(self.__repr__())


# give lists and dicts a _repr_html_ if they contain SyftObject's
aggressive_set_attr(type([]), "_repr_html_", list_dict_repr_html)
aggressive_set_attr(type({}), "_repr_html_", list_dict_repr_html)
Expand Down
Loading

0 comments on commit feb7a59

Please sign in to comment.