diff --git a/news/458.feature.rst b/news/458.feature.rst
new file mode 100644
index 0000000000..847bf0a40d
--- /dev/null
+++ b/news/458.feature.rst
@@ -0,0 +1,3 @@
+``memray attach`` has been enhanced to allow tracking for only a set period of
+time, or until a set heap size is reached. You can also manually deactivate
+tracking that was started by a previous call to ``memray attach``.
diff --git a/src/memray/commands/attach.py b/src/memray/commands/attach.py
index 927fe338ef..a043e1d5d1 100644
--- a/src/memray/commands/attach.py
+++ b/src/memray/commands/attach.py
@@ -19,41 +19,148 @@
 from .live import LiveCommand
 from .run import _get_free_port
 
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal  # type: ignore
+
+TrackingMode = Literal["ACTIVATE", "DEACTIVATE", "UNTIL_HEAP_SIZE", "FOR_DURATION"]
+
+
 GDB_SCRIPT = pathlib.Path(__file__).parent / "_attach.gdb"
 LLDB_SCRIPT = pathlib.Path(__file__).parent / "_attach.lldb"
 RTLD_DEFAULT = memray._memray.RTLD_DEFAULT
 RTLD_NOW = memray._memray.RTLD_NOW
 PAYLOAD = """
 import atexit
+import time
+import threading
+import resource
+import sys
+from contextlib import suppress
 
 import memray
 
 
+def _get_current_heap_size() -> int:
+    # NOTE: ru_maxrss is the peak resident set size of the process; it is
+    # used here as an approximation of the heap size.
+    usage = resource.getrusage(resource.RUSAGE_SELF)
+    if sys.platform == "darwin":
+        return usage.ru_maxrss  # macOS already reports bytes
+    return usage.ru_maxrss * 1024  # Linux reports kilobytes
+
+
+# Calls `function` every `interval` seconds until it returns a true value
+# or `cancel()` is called.
+class RepeatingTimer(threading.Thread):
+    def __init__(self, interval, function):
+        self._interval = interval
+        self._function = function
+        self._canceled = threading.Event()
+        # Daemonize so a still-pending poll can't block interpreter shutdown:
+        # non-daemon threads are joined before atexit handlers get to run.
+        super().__init__(daemon=True)
+
+    def cancel(self):
+        self._canceled.set()
+
+    def run(self):
+        while not self._canceled.wait(self._interval):
+            if self._function():
+                break
+
+
 def deactivate_last_tracker():
     tracker = getattr(memray, "_last_tracker", None)
     if not tracker:
         return
 
     memray._last_tracker = None
-    tracker.__exit__(None, None, None)
+    try:
+        tracker.__exit__(None, None, None)
+    finally:
+        # Clean up resources associated with the Tracker ASAP,
+        # even if an exception was raised.
+        del tracker
+
+        # Stop any waiting threads. This attribute may be unset if an old Memray
+        # version attached 1st, setting last_tracker but not _attach_event_threads.
+        # It could also be unset if we're racing another deactivate call.
+        for thread in memray.__dict__.pop("_attach_event_threads", []):
+            thread.cancel()
+
+
+# Replace any tracker left by a previous attach with a freshly started one.
+def activate_tracker():
+    deactivate_last_tracker()
+    tracker = {tracker_call}
+    try:
+        tracker.__enter__()
+        memray._last_tracker = tracker
+    finally:
+        # Clean up resources associated with the Tracker ASAP,
+        # even if an exception was raised.
+        del tracker
+    memray._attach_event_threads = []
+
+
+# Start tracking, then poll once per second and stop tracking as soon as
+# the process's memory usage reaches `heap_size` bytes.
+def track_until_heap_size(heap_size):
+    activate_tracker()
+
+    def check_heap_size() -> bool:
+        current_heap_size = _get_current_heap_size()
+        if current_heap_size >= heap_size:
+            print(
+                "memray: Deactivating tracking: heap size has reached",
+                current_heap_size,
+                "bytes, the limit was",
+                heap_size,
+                file=sys.stderr,
+            )
+            deactivate_last_tracker()
+            return True  # Condition we were waiting for has happened
+        return False  # Keep polling
+
+    thread = RepeatingTimer(1, check_heap_size)
+    thread.start()
+    memray._attach_event_threads.append(thread)
+
+
+# Start tracking and stop again after `duration` seconds.
+def track_for_duration(duration=5):
+    activate_tracker()
+
+    def deactivate_because_timer_elapsed():
+        print(
+            "memray: Deactivating tracking:",
+            duration,
+            "seconds have elapsed",
+            file=sys.stderr,
+        )
+        deactivate_last_tracker()
+
+    thread = threading.Timer(duration, deactivate_because_timer_elapsed)
+    # Daemonize so the pending timer can't delay interpreter shutdown.
+    thread.daemon = True
+    thread.start()
+    memray._attach_event_threads.append(thread)
 
 
 if not hasattr(memray, "_last_tracker"):
     # This only needs to be registered the first time we attach.
     atexit.register(deactivate_last_tracker)
 
-deactivate_last_tracker()
-
-tracker = {tracker_call}
-try:
-    tracker.__enter__()
-except:
-    # Prevent the exception from keeping the tracker alive.
-    # This way resources are cleaned up ASAP.
-    del tracker
-    raise
-
-memray._last_tracker = tracker
+if {mode!r} == "ACTIVATE":
+    activate_tracker()
+elif {mode!r} == "DEACTIVATE":
+    deactivate_last_tracker()
+elif {mode!r} == "UNTIL_HEAP_SIZE":
+    track_until_heap_size({heap_size})
+elif {mode!r} == "FOR_DURATION":
+    track_for_duration({duration})
 """
 
 
@@ -281,6 +388,23 @@ def prepare_parser(self, parser: argparse.ArgumentParser) -> None:
             action="store_true",
         )
 
+        mode = parser.add_mutually_exclusive_group()
+
+        mode.add_argument(
+            "--stop-tracking",
+            action="store_true",
+            help="Stop any tracker installed by a previous `memray attach` call",
+            default=False,
+        )
+
+        mode.add_argument(
+            "--heap-limit", type=int, help="Heap size to track until (in bytes)"
+        )
+
+        mode.add_argument(
+            "--duration", type=int, help="Duration to track for (in seconds)"
+        )
+
         parser.add_argument(
             "--method",
             help="Method to use for injecting code into the process to track",
@@ -304,6 +428,34 @@ def prepare_parser(self, parser: argparse.ArgumentParser) -> None:
 
     def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
         verbose = args.verbose
+        mode: TrackingMode = "ACTIVATE"
+        duration = None
+        heap_size = None
+
+        if args.stop_tracking:
+            if args.output:
+                parser.error("Can't use --stop-tracking with -o or --output")
+            if args.force:
+                parser.error("Can't use --stop-tracking with -f or --force")
+            if args.aggregate:
+                parser.error("Can't use --stop-tracking with --aggregate")
+            if args.native:
+                parser.error("Can't use --stop-tracking with --native")
+            if args.follow_fork:
+                parser.error("Can't use --stop-tracking with --follow-fork")
+            if args.trace_python_allocators:
+                parser.error("Can't use --stop-tracking with --trace-python-allocators")
+            if args.no_compress:
+                parser.error("Can't use --stop-tracking with --no-compress")
+            mode = "DEACTIVATE"
+        elif args.heap_limit is not None:
+            # Compared against None so an explicit 0 isn't silently ignored.
+            mode = "UNTIL_HEAP_SIZE"
+            heap_size = args.heap_limit
+        elif args.duration is not None:
+            # Compared against None so an explicit 0 isn't silently ignored.
+            mode = "FOR_DURATION"
+            duration = args.duration
 
         if args.method == "auto":
             # Prefer gdb on Linux but lldb on macOS
@@ -368,7 +520,14 @@ def run(self, args: argparse.Namespace, parser: argparse.ArgumentParser) -> None
 
             client = server.accept()[0]
 
-        client.sendall(PAYLOAD.format(tracker_call=tracker_call).encode("utf-8"))
+        client.sendall(
+            PAYLOAD.format(
+                tracker_call=tracker_call,
+                mode=mode,
+                heap_size=heap_size,
+                duration=duration,
+            ).encode("utf-8")
+        )
         client.shutdown(socket.SHUT_WR)
 
         if not live_port:
diff --git a/tests/unit/test_attach.py b/tests/unit/test_attach.py
index 7ca44e4426..31c8693bb7 100644
--- a/tests/unit/test_attach.py
+++ b/tests/unit/test_attach.py
@@ -18,4 +18,54 @@ def test_memray_attach_aggregated_without_output_file(
         main(["attach", "--aggregate", "1234"])
 
     captured = capsys.readouterr()
     assert "Can't use aggregated mode without an output file." in captured.err
+
+
+class TestAttachSubCommandOptions:
+    @pytest.mark.parametrize(
+        "option",
+        [
+            ["--output", "foo"],
+            ["-o", "foo"],
+            ["--native"],
+            ["--force"],
+            ["-f"],
+            ["--aggregate"],
+            ["--follow-fork"],
+            ["--trace-python-allocators"],
+            ["--no-compress"],
+        ],
+    )
+    def test_memray_attach_stop_tracking_option_with_other_options(
+        self, option, capsys
+    ):
+        # WHEN
+        with pytest.raises(SystemExit):
+            main(["attach", "1234", "--stop-tracking", *option])
+
+        captured = capsys.readouterr()
+        assert "Can't use --stop-tracking with" in captured.err
+        assert option[0] in captured.err.split()
+
+    @pytest.mark.parametrize(
+        "arg1,arg2",
+        [
+            ("--stop-tracking", "--heap-limit=10"),
+            ("--stop-tracking", "--duration=10"),
+            ("--heap-limit=10", "--duration=10"),
+        ],
+    )
+    def test_memray_attach_stop_tracking_option_with_other_mode_options(
+        self, arg1, arg2, capsys
+    ):
+        # WHEN
+        with pytest.raises(SystemExit):
+            main(["attach", "1234", arg1, arg2])
+
+        captured = capsys.readouterr()
+        arg1_name = arg1.split("=")[0]
+        arg2_name = arg2.split("=")[0]
+        assert (
+            f"argument {arg2_name}: not allowed with argument {arg1_name}"
+            in captured.err
+        )