diff --git a/pymobiledevice3/remote/remotexpc.py b/pymobiledevice3/remote/remotexpc.py index d26fec28..e11d9008 100644 --- a/pymobiledevice3/remote/remotexpc.py +++ b/pymobiledevice3/remote/remotexpc.py @@ -14,7 +14,6 @@ from pymobiledevice3.exceptions import StreamClosedError from pymobiledevice3.remote.xpc_message import XpcFlags, XpcInt64Type, XpcUInt64Type, XpcWrapper, create_xpc_wrapper, \ decode_xpc_object -from pymobiledevice3.service_connection import ServiceConnection # Extracted by sniffing `remoted` traffic via Wireshark DEFAULT_SETTINGS_MAX_CONCURRENT_STREAMS = 100 @@ -39,9 +38,10 @@ class RemoteXPCConnection: def __init__(self, address: Tuple[str, int]): self._previous_frame_data = b'' self.address = address - self.service_connection: Optional[ServiceConnection] = None self.next_message_id: Mapping[int: int] = {ROOT_CHANNEL: 0, REPLY_CHANNEL: 0} self.peer_info = None + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None async def __aenter__(self) -> 'RemoteXPCConnection': await self.connect() @@ -50,20 +50,26 @@ async def __aenter__(self) -> 'RemoteXPCConnection': async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.close() - async def connect(self, keep_alive: bool = True) -> None: - self.service_connection = ServiceConnection.create_using_tcp(self.address[0], self.address[1], - keep_alive=keep_alive) - await self.service_connection.aio_start() + async def connect(self) -> None: + self._reader, self._writer = await asyncio.open_connection(self.address[0], self.address[1]) await self._do_handshake() async def close(self) -> None: - if self.service_connection is not None: - await self.service_connection.aio_close() + if self._writer is None: + return + self._writer.close() + try: + await self._writer.wait_closed() + except ConnectionResetError: + pass + self._writer = None + self._reader = None async def send_request(self, data: Mapping, wanting_reply: bool = False) -> None: xpc_wrapper = create_xpc_wrapper( data, message_id=self.next_message_id[ROOT_CHANNEL], wanting_reply=wanting_reply) - await self.service_connection.aio_sendall(DataFrame(stream_id=ROOT_CHANNEL, data=xpc_wrapper).serialize()) + self._writer.write(DataFrame(stream_id=ROOT_CHANNEL, data=xpc_wrapper).serialize()) + await self._writer.drain() async def iter_file_chunks(self, total_size: int, file_idx: int = 0) -> Generator[bytes, None, None]: stream_id = (file_idx + 1) * 2 @@ -124,7 +130,8 @@ def shell(self) -> None: }) async def _do_handshake(self) -> None: - await self.service_connection.aio_sendall(HTTP2_MAGIC) + self._writer.write(HTTP2_MAGIC) + await self._writer.drain() # send h2 headers await self._send_frame(SettingsFrame(settings={ @@ -146,14 +153,15 @@ async def _do_handshake(self) -> None: await self._send_frame(SettingsFrame(flags=['ACK'])) - async def _open_channel(self, stream_id: int, flags: int): + async def _open_channel(self, stream_id: int, flags: int) -> None: flags |= XpcFlags.ALWAYS_SET await self._send_frame(HeadersFrame(stream_id=stream_id, flags=['END_HEADERS'])) await self._send_frame( DataFrame(stream_id=stream_id, data=XpcWrapper.build({'size': 0, 'flags': flags, 'payload': None}))) async def _send_frame(self, frame: Frame) -> None: - await self.service_connection.aio_sendall(frame.serialize()) + self._writer.write(frame.serialize()) + await self._writer.drain() async def _receive_next_data_frame(self) -> DataFrame: while True: @@ -173,7 +181,7 @@ async def _receive_next_data_frame(self) -> DataFrame: return frame async def _receive_frame(self) -> Frame: - buf = await self._recvall(FRAME_HEADER_SIZE) + buf = await self._reader.readexactly(FRAME_HEADER_SIZE) frame, additional_size = Frame.parse_frame_header(memoryview(buf)) frame.parse_body(memoryview(await self._recvall(additional_size))) return frame @@ -182,7 +190,7 @@ async def _recvall(self, size: int) -> bytes: data = b'' while len(data) < size: try: - chunk = await self.service_connection.aio_recvall(size - len(data)) + chunk = await self._reader.readexactly(size - len(data)) except IncompleteReadError: raise ConnectionAbortedError() data += chunk