Skip to content

Commit

Permalink
Merge pull request #8452 from OpenMined/freeze-protocol-version
Browse files Browse the repository at this point in the history
Freeze protocol version
  • Loading branch information
yashgorana committed Feb 9, 2024
2 parents a7ae983 + fd09c5e commit 22b903c
Show file tree
Hide file tree
Showing 7 changed files with 1,192 additions and 1,010 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cd-syft.yml
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ jobs:
author_name: ${{ secrets.OM_BOT_NAME }}
author_email: ${{ secrets.OM_BOT_EMAIL }}
message: "[syft]bump version"
add: "['.bumpversion.cfg', 'VERSION', 'packages/grid/VERSION','packages/syft/PYPI.md', 'packages/grid/devspace.yaml', 'packages/syft/src/syft/VERSION', 'packages/syft/setup.cfg', 'packages/grid/frontend/package.json', 'packages/syft/src/syft/__init__.py', 'packages/hagrid/hagrid/manifest_template.yml', 'packages/grid/helm/syft/Chart.yaml','packages/grid/helm/repo', 'packages/hagrid/hagrid/deps.py', 'packages/grid/podman/podman-kube/podman-syft-kube.yaml' ,'packages/grid/podman/podman-kube/podman-syft-kube-config.yaml', 'packages/syftcli/manifest.yml', 'packages/syft/src/syft/protocol/protocol_version.json', 'packages/grid/backend/worker_cpu.dockerfile','packages/grid/helm/syft/values.yaml','packages/grid/helm/syft']"
add: "['.bumpversion.cfg', 'VERSION', 'packages/grid/VERSION','packages/syft/PYPI.md', 'packages/grid/devspace.yaml', 'packages/syft/src/syft/VERSION', 'packages/syft/setup.cfg', 'packages/grid/frontend/package.json', 'packages/syft/src/syft/__init__.py', 'packages/hagrid/hagrid/manifest_template.yml', 'packages/grid/helm/syft/Chart.yaml','packages/grid/helm/repo', 'packages/hagrid/hagrid/deps.py', 'packages/grid/podman/podman-kube/podman-syft-kube.yaml' ,'packages/grid/podman/podman-kube/podman-syft-kube-config.yaml', 'packages/syftcli/manifest.yml', 'packages/syft/src/syft/protocol/protocol_version.json', 'packages/syft/src/syft/protocol/releases/', 'packages/grid/backend/worker_cpu.dockerfile','packages/grid/helm/syft/values.yaml','packages/grid/helm/syft']"

- name: Changes to commit to Syft Repo during stable release
if: needs.merge-docker-images.outputs.release_tag == 'latest'
Expand All @@ -419,7 +419,7 @@ jobs:
author_name: ${{ secrets.OM_BOT_NAME }}
author_email: ${{ secrets.OM_BOT_EMAIL }}
message: "[syft] bump protocol version"
add: "['packages/syft/src/syft/protocol/protocol_version.json']"
add: "['packages/syft/src/syft/protocol/protocol_version.json', 'packages/syft/src/syft/protocol/releases/']"

- name: Scheduled Build and Publish
if: github.event_name == 'schedule'
Expand Down
152 changes: 140 additions & 12 deletions packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
import re
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union

# third party
from packaging.version import parse
from result import OkErr
from result import Result

# relative
from .. import __version__
from ..serde.recursive import TYPE_BANK
from ..service.response import SyftError
from ..service.response import SyftException
Expand All @@ -30,29 +33,33 @@
PROTOCOL_TYPE = Union[str, int]


def natural_key(key: PROTOCOL_TYPE) -> list[int]:
def natural_key(key: PROTOCOL_TYPE) -> List[int]:
"""Define key for natural ordering of strings."""
if isinstance(key, int):
key = str(key)
return [int(s) if s.isdigit() else s for s in re.split("(\d+)", key)]


def sort_dict_naturally(d: dict) -> dict:
def sort_dict_naturally(d: Dict) -> Dict:
"""Sort dictionary by keys in natural order."""
return {k: d[k] for k in sorted(d.keys(), key=natural_key)}


def data_protocol_file_name():
def data_protocol_file_name() -> str:
return PROTOCOL_STATE_FILENAME


def data_protocol_dir():
return os.path.abspath(str(Path(__file__).parent))
def data_protocol_dir() -> Path:
return Path(os.path.abspath(str(Path(__file__).parent)))


def protocol_release_dir() -> Path:
return data_protocol_dir() / "releases"


class DataProtocol:
def __init__(self, filename: str) -> None:
self.file_path = Path(data_protocol_dir()) / filename
self.file_path = data_protocol_dir() / filename
self.load_state()

def load_state(self) -> None:
Expand All @@ -78,13 +85,34 @@ def _calculate_object_hash(klass: Type[SyftBaseObject]) -> str:

return hashlib.sha256(json.dumps(obj_meta_info).encode()).hexdigest()

def read_history(self) -> Dict:
@staticmethod
def read_json(file_path: Path) -> Dict:
try:
return json.loads(self.file_path.read_text())
return json.loads(file_path.read_text())
except Exception:
return {}

def save_history(self, history: dict) -> None:
def read_history(self) -> Dict:
protocol_history = self.read_json(self.file_path)

for version in protocol_history.keys():
if version == "dev":
continue
release_version_path = (
protocol_release_dir() / protocol_history[version]["release_name"]
)
released_version = self.read_json(file_path=release_version_path)
protocol_history[version] = released_version.get(version, {})

return protocol_history

def save_history(self, history: Dict) -> None:
for file_path in protocol_release_dir().iterdir():
for version in self.read_json(file_path):
# Skip adding file if the version is not part of the history
if version not in history.keys():
continue
history[version] = {"release_name": file_path.name}
self.file_path.write_text(json.dumps(history, indent=2) + "\n")

@property
Expand Down Expand Up @@ -136,7 +164,7 @@ def build_state(self, stop_key: Optional[str] = None) -> dict:
return state_dict
return state_dict

def diff_state(self, state: dict) -> tuple[dict, dict]:
def diff_state(self, state: Dict) -> tuple[Dict, Dict]:
compare_dict = defaultdict(dict) # what versions are in the latest code
object_diff = defaultdict(dict) # diff in latest code with saved json
for k in TYPE_BANK:
Expand Down Expand Up @@ -274,6 +302,7 @@ def bump_protocol_version(self) -> Result[SyftSuccess, SyftError]:

keys = self.protocol_history.keys()
if "dev" not in keys:
self.validate_release()
print("You can't bump the protocol if there are no staged changes.")
return SyftError(
message="Failed to bump version as there are no staged changes."
Expand All @@ -287,11 +316,110 @@ def bump_protocol_version(self) -> Result[SyftSuccess, SyftError]:

next_highest_protocol = highest_protocol + 1
self.protocol_history[str(next_highest_protocol)] = self.protocol_history["dev"]
self.freeze_release(self.protocol_history, str(next_highest_protocol))
del self.protocol_history["dev"]
self.save_history(self.protocol_history)
self.load_state()
return SyftSuccess(message=f"Protocol Updated to {next_highest_protocol}")

@staticmethod
def freeze_release(protocol_history: Dict, latest_protocol: str) -> None:
"""Freezes latest release as a separate release file."""

# Get release history
release_history = protocol_history[latest_protocol]

# Create new file for the version
syft_version = parse(__version__)
release_file_name = f"{syft_version.public}.json"
release_file = protocol_release_dir() / release_file_name

# Save the new released version
release_file.write_text(
json.dumps({latest_protocol: release_history}, indent=2)
)

def validate_release(self) -> None:
"""Validate if latest release name is consistent with syft version"""
# Read the protocol history
protocol_history = self.read_json(self.file_path)
sorted_protocol_versions = sorted(protocol_history.keys(), key=natural_key)

# Grab the latest protocol
latest_protocol = (
sorted_protocol_versions[-1] if len(sorted_protocol_versions) > 0 else None
)

# Skip validation if latest protocol is dev
if latest_protocol is None or latest_protocol == "dev":
return

# Get filename of the latest protocol
release_name = protocol_history[latest_protocol]["release_name"]
# Extract syft version from release name
protocol_syft_version = parse(release_name.split(".json")[0])
current_syft_version = parse(__version__)

# If base syft version in latest protocol version is not same as current syft version
# Skip updating the release name
if protocol_syft_version.base_version != current_syft_version.base_version:
return

# Update release name to latest beta, stable or post based on current syft version
print(
f"Current release {release_name} will be updated to {current_syft_version}"
)

# Get latest protocol file path
latest_protocol_fp: Path = protocol_release_dir() / release_name

# New protocol file path
new_protocol_file_path = (
protocol_release_dir() / f"{current_syft_version.public}.json"
)

# Update older file path to newer file path
latest_protocol_fp.rename(new_protocol_file_path)
protocol_history[latest_protocol][
"release_name"
] = f"{current_syft_version}.json"

# Save history
self.file_path.write_text(json.dumps(protocol_history, indent=2) + "\n")

# Reload protocol
self.read_history()

def revert_latest_protocol(self) -> Result[SyftSuccess, SyftError]:
"""Revert latest protocol changes to dev"""

# Get current protocol history
protocol_history = self.read_json(self.file_path)

# Get latest released protocol
sorted_protocol_versions = sorted(protocol_history.keys(), key=natural_key)
latest_protocol = (
sorted_protocol_versions[-1] if len(sorted_protocol_versions) > 0 else None
)

# If current protocol is dev, skip revert
if latest_protocol is None or latest_protocol == "dev":
return SyftError(message="Revert skipped !! Already running dev protocol.")

# Read the current released protocol
release_name = protocol_history[latest_protocol]["release_name"]
protocol_file_path: Path = protocol_release_dir() / release_name

released_protocol = self.read_json(protocol_file_path)
protocol_history["dev"] = released_protocol[latest_protocol]

# Delete the current released protocol
protocol_history.pop(latest_protocol)
protocol_file_path.unlink()

# Save history
self.save_history(protocol_history)

def check_protocol(self) -> Result[SyftSuccess, SyftError]:
if len(self.diff) != 0:
return SyftError(message="Protocol Changes Unstaged")
Expand Down Expand Up @@ -338,7 +466,7 @@ def has_dev(self) -> bool:
return False


def get_data_protocol():
def get_data_protocol() -> DataProtocol:
return DataProtocol(filename=data_protocol_file_name())


Expand All @@ -357,7 +485,7 @@ def check_or_stage_protocol() -> Result[SyftSuccess, SyftError]:
return data_protocol.check_or_stage_protocol()


def debox_arg_and_migrate(arg: Any, protocol_state: dict):
def debox_arg_and_migrate(arg: Any, protocol_state: dict) -> Any:
"""Debox the argument based on whether it is iterable or single entity."""
constructor = None
extra_args = []
Expand Down
Loading

0 comments on commit 22b903c

Please sign in to comment.