Skip to content

Commit

Permalink
Merge pull request #1092 from doronz88/refactor/remotexpc-tcp-async
Browse files Browse the repository at this point in the history
remotexpc: make tcp connections async
  • Loading branch information
doronz88 committed Jun 25, 2024
2 parents 6f99003 + ade22e6 commit 34693e1
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions pymobiledevice3/remote/remotexpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from pymobiledevice3.exceptions import StreamClosedError
from pymobiledevice3.remote.xpc_message import XpcFlags, XpcInt64Type, XpcUInt64Type, XpcWrapper, create_xpc_wrapper, \
decode_xpc_object
from pymobiledevice3.service_connection import ServiceConnection

# Extracted by sniffing `remoted` traffic via Wireshark
DEFAULT_SETTINGS_MAX_CONCURRENT_STREAMS = 100
Expand All @@ -39,9 +38,10 @@ class RemoteXPCConnection:
def __init__(self, address: Tuple[str, int]):
self._previous_frame_data = b''
self.address = address
self.service_connection: Optional[ServiceConnection] = None
self.next_message_id: Mapping[int: int] = {ROOT_CHANNEL: 0, REPLY_CHANNEL: 0}
self.peer_info = None
self._reader: Optional[asyncio.StreamReader] = None
self._writer: Optional[asyncio.StreamWriter] = None

async def __aenter__(self) -> 'RemoteXPCConnection':
await self.connect()
Expand All @@ -50,20 +50,26 @@ async def __aenter__(self) -> 'RemoteXPCConnection':
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.close()

async def connect(self, keep_alive: bool = True) -> None:
self.service_connection = ServiceConnection.create_using_tcp(self.address[0], self.address[1],
keep_alive=keep_alive)
await self.service_connection.aio_start()
async def connect(self) -> None:
self._reader, self._writer = await asyncio.open_connection(self.address[0], self.address[1])
await self._do_handshake()

async def close(self) -> None:
if self.service_connection is not None:
await self.service_connection.aio_close()
if self._writer is None:
return
self._writer.close()
try:
await self._writer.wait_closed()
except ConnectionResetError:
pass
self._writer = None
self._reader = None

async def send_request(self, data: Mapping, wanting_reply: bool = False) -> None:
xpc_wrapper = create_xpc_wrapper(
data, message_id=self.next_message_id[ROOT_CHANNEL], wanting_reply=wanting_reply)
await self.service_connection.aio_sendall(DataFrame(stream_id=ROOT_CHANNEL, data=xpc_wrapper).serialize())
self._writer.write(DataFrame(stream_id=ROOT_CHANNEL, data=xpc_wrapper).serialize())
await self._writer.drain()

async def iter_file_chunks(self, total_size: int, file_idx: int = 0) -> Generator[bytes, None, None]:
stream_id = (file_idx + 1) * 2
Expand Down Expand Up @@ -124,7 +130,8 @@ def shell(self) -> None:
})

async def _do_handshake(self) -> None:
await self.service_connection.aio_sendall(HTTP2_MAGIC)
self._writer.write(HTTP2_MAGIC)
await self._writer.drain()

# send h2 headers
await self._send_frame(SettingsFrame(settings={
Expand All @@ -146,14 +153,15 @@ async def _do_handshake(self) -> None:

await self._send_frame(SettingsFrame(flags=['ACK']))

async def _open_channel(self, stream_id: int, flags: int):
async def _open_channel(self, stream_id: int, flags: int) -> None:
flags |= XpcFlags.ALWAYS_SET
await self._send_frame(HeadersFrame(stream_id=stream_id, flags=['END_HEADERS']))
await self._send_frame(
DataFrame(stream_id=stream_id, data=XpcWrapper.build({'size': 0, 'flags': flags, 'payload': None})))

async def _send_frame(self, frame: Frame) -> None:
await self.service_connection.aio_sendall(frame.serialize())
self._writer.write(frame.serialize())
await self._writer.drain()

async def _receive_next_data_frame(self) -> DataFrame:
while True:
Expand All @@ -173,7 +181,7 @@ async def _receive_next_data_frame(self) -> DataFrame:
return frame

async def _receive_frame(self) -> Frame:
buf = await self._recvall(FRAME_HEADER_SIZE)
buf = await self._reader.readexactly(FRAME_HEADER_SIZE)
frame, additional_size = Frame.parse_frame_header(memoryview(buf))
frame.parse_body(memoryview(await self._recvall(additional_size)))
return frame
Expand All @@ -182,7 +190,7 @@ async def _recvall(self, size: int) -> bytes:
data = b''
while len(data) < size:
try:
chunk = await self.service_connection.aio_recvall(size - len(data))
chunk = await self._reader.readexactly(size - len(data))
except IncompleteReadError:
raise ConnectionAbortedError()
data += chunk
Expand Down

0 comments on commit 34693e1

Please sign in to comment.