Skip to content

Commit

Permalink
Implements post_transaction_hook
Browse files Browse the repository at this point in the history
See issue #24.
  • Loading branch information
hirak99 committed Nov 12, 2023
1 parent b752547 commit 06c0811
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 9 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ keep_daily = 5
keep_weekly = 0
keep_monthly = 0
keep_yearly = 0

# Uncomment example to specify scripts to run after yabsnap creates or deletes any snap.
# If any creation / deletion operation occurs, each script will be called with the `source` as an
# argument. Use space as delimiter to specify multiple scripts.
# Example -
# post_transaction_scripts = "/home/me/script1.sh" "/home/me/script2.sh"
```

# Command Line Interface
Expand Down
12 changes: 11 additions & 1 deletion src/code/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import dataclasses
import datetime
import logging
import pathlib
import os
import pathlib
import shlex

from . import human_interval
from . import os_utils
Expand Down Expand Up @@ -64,6 +65,8 @@ class Config:
keep_monthly: int = 0
keep_yearly: int = 0

post_transaction_scripts: list[str] = dataclasses.field(default_factory=list)

def is_schedule_enabled(self) -> bool:
return (
self.keep_hourly > 0
Expand All @@ -84,6 +87,9 @@ def from_configfile(cls, config_file: str) -> "Config":
dest_prefix=section["dest_prefix"],
)
for key, value in section.items():
if key == "post_transaction_scripts":
result.post_transaction_scripts = shlex.split(value)
continue
if not hasattr(result, key):
logging.warning(f"Invalid field {key=} found in {config_file=}")
if key.endswith("_interval"):
Expand All @@ -106,6 +112,10 @@ def deletion_rules(self) -> list[tuple[datetime.timedelta, int]]:
def mount_path(self) -> str:
return os.path.dirname(self.dest_prefix)

def call_post_hooks(self) -> None:
for script in self.post_transaction_scripts:
os_utils.run_user_script(script, [self.source])


def iterate_configs(source: Optional[str]) -> Iterator[Config]:
config_iterator: Iterable[str]
Expand Down
27 changes: 24 additions & 3 deletions src/code/configs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock
import os
import tempfile
import unittest
Expand All @@ -32,12 +33,32 @@ def test_default_config(self):
)
self.assertEqual(config, expected_config)

def test_post_transaction_scripts(self):
with tempfile.NamedTemporaryFile(prefix="yabsnap_config_test_") as file:
with open(configs._example_config_fname()) as example_file:
for line in example_file:
file.write(line.encode())

# Add a new line.
file.write(
'post_transaction_scripts = script1.sh "/my directory/script2.sh"\n'.encode(),
)
file.flush()

read_config = configs.Config.from_configfile(file.name)
self.assertEqual(
read_config.post_transaction_scripts,
["script1.sh", "/my directory/script2.sh"],
)

def test_create_config(self):
with tempfile.NamedTemporaryFile(prefix="yabsnap_config_test_") as file:
configs.USER_CONFIG_FILE = file.name
# Don't need the file; in fact if it exists we cannot create it.
os.remove(file.name)
# Create -
configs.create_config("configname", "source")

# Create (i.e write to file.name) -
with mock.patch.object(configs, "USER_CONFIG_FILE", file.name):
configs.create_config("configname", "source")

# Read back -
read_config = configs.Config.from_configfile(file.name)
Expand Down
6 changes: 6 additions & 0 deletions src/code/example_config.conf
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,9 @@ keep_daily = 5
keep_weekly = 0
keep_monthly = 0
keep_yearly = 0

# Uncomment example to specify scripts to run after yabsnap creates or deletes any snap.
# If any creation / deletion operation occurs, each script will be called with the `source` as an
# argument. Use space as delimiter to specify multiple scripts.
# Example -
# post_transaction_scripts = "/home/me/script1.sh" "/home/me/script2.sh"
7 changes: 6 additions & 1 deletion src/code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse
import datetime
import logging
import subprocess
from typing import Iterable

from . import colored_logs
Expand Down Expand Up @@ -94,6 +95,8 @@ def _delete_snap(configs_iter: Iterable[configs.Config], path_suffix: str, sync:
snap.delete()
mount_paths.add(config.mount_path)

config.call_post_hooks()

if sync:
_btrfs_sync(mount_paths)

Expand Down Expand Up @@ -124,8 +127,10 @@ def _config_operation(command: str, source: str, comment: str, sync: bool):
else:
raise ValueError(f"Command not implemented: {command}")

if snapper.need_sync:
if snapper.snaps_deleted:
mount_paths_to_sync.add(config.mount_path)
if snapper.snaps_created or snapper.snaps_deleted:
config.call_post_hooks()

if sync:
_btrfs_sync(mount_paths_to_sync)
Expand Down
12 changes: 12 additions & 0 deletions src/code/os_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ def execute_sh(command: str, error_ok: bool = False) -> Optional[str]:
return None


def run_user_script(script_name: str, args: list[str]) -> bool:
try:
subprocess.check_call([script_name] + args)
except FileNotFoundError:
logging.warning(f"User script {script_name=} does not exist.")
return False
except subprocess.CalledProcessError:
logging.warning(f"User script {script_name=} with {args=} resulted in error.")
return False
return True


def is_btrfs_volume(mount_point: str) -> bool:
"""Test if directory is a btrfs volume."""
# Based on https://stackoverflow.com/a/32865333/196462
Expand Down
36 changes: 36 additions & 0 deletions src/code/os_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

from . import os_utils

# For testing, we can access private methods.
# pyright: reportPrivateUsage=false


class OsUtilsTest(unittest.TestCase):
def test_run_user_script(self):
self.assertTrue(os_utils.run_user_script("true", []))
self.assertTrue(os_utils.run_user_script("test", ["1", "=", "1"]))
self.assertFalse(os_utils.run_user_script("test", ["1", "=", "2"]))
with tempfile.TemporaryDirectory() as dir:
# Script does not exist.
self.assertFalse(os_utils.run_user_script(os.path.join(dir, "test.sh"), []))


if __name__ == "__main__":
unittest.main()
12 changes: 8 additions & 4 deletions src/code/snap_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ def __init__(self, config: configs.Config, now: datetime.datetime) -> None:
self._config = config
self._now = now
self._now_str = self._now.strftime(snap_holder.TIME_FORMAT)
# Set to true on any delete operation.
self.need_sync = False
# Set to true on any create operation.
self.snaps_created = False
# Set to true on any delete operation. If True, may run a btrfs subv sync.
self.snaps_deleted = False

def _apply_deletion_rules(self, snaps: Iterable[snap_holder.Snapshot]) -> bool:
"""Deletes old backups. Returns True if new backup is needed."""
Expand All @@ -89,7 +91,7 @@ def _apply_deletion_rules(self, snaps: Iterable[snap_holder.Snapshot]) -> bool:
elapsed_secs = (self._now - when).total_seconds()
if elapsed_secs > self._config.min_keep_secs:
snap_holder.Snapshot(target).delete()
self.need_sync = True
self.snaps_deleted = True
else:
logging.info(f"Not enough time passed, not deleting {target}")

Expand All @@ -114,11 +116,12 @@ def _create_and_maintain_n_backups(
if comment:
snapshot.metadata.comment = comment
snapshot.create_from(self._config.source)
self.snaps_created = True

# Clean up old snaps; leave count-1 previous snaps (plus the one now created).
for expired in _all_but_last_k(previous_snaps, count - 1):
expired.delete()
self.need_sync = True
self.snaps_deleted = True

def create(self, comment: Optional[str]):
try:
Expand Down Expand Up @@ -191,6 +194,7 @@ def scheduled(self):
snapshot = snap_holder.Snapshot(self._config.dest_prefix + self._now_str)
snapshot.metadata.trigger = "S"
snapshot.create_from(self._config.source)
self.snaps_created = True

def list_snaps(self):
"""Print the backups for humans."""
Expand Down

0 comments on commit 06c0811

Please sign in to comment.