Skip to content

Commit

Permalink
Merge pull request #230 from yotamN/feature/async-exports
Browse files Browse the repository at this point in the history
Add async variant to rpc calls
  • Loading branch information
yotamN authored Feb 3, 2023
2 parents 1b1e345 + fb86c3a commit 1d35663
Showing 1 changed file with 120 additions and 24 deletions.
144 changes: 120 additions & 24 deletions frida/core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from __future__ import annotations

import asyncio
import dataclasses
import fnmatch
import functools
import json
import sys
import threading
import traceback
import warnings
from types import TracebackType
from typing import (
Any,
AnyStr,
Awaitable,
Callable,
Dict,
List,
Expand Down Expand Up @@ -156,7 +161,7 @@ def terminate(self) -> None:
self._impl.terminate()


class ScriptExports:
class ScriptExportsSync:
"""
Proxy object that expose all the RPC exports of a script as attributes on this class
Expand All @@ -166,7 +171,7 @@ class ScriptExports:
def __init__(self, script: "Script") -> None:
self._script = script

def __getattr__(self, name: str) -> Any:
def __getattr__(self, name: str) -> Callable[..., Any]:
script = self._script
js_name = _to_camel_case(name)

Expand All @@ -176,7 +181,33 @@ def method(*args: Any, **kwargs: Any) -> Any:
return method

def __dir__(self) -> List[str]:
return self._script.list_exports()
return self._script.list_exports_sync()


ScriptExports = ScriptExportsSync


class ScriptExportsAsync:
"""
Proxy object that expose all the RPC exports of a script as attributes on this class
A method named exampleMethod in a script will be called with instance.example_method on this object
"""

def __init__(self, script: "Script") -> None:
self._script = script

def __getattr__(self, name: str) -> Callable[..., Awaitable[Any]]:
script = self._script
js_name = _to_camel_case(name)

async def method(*args: Any, **kwargs: Any) -> Any:
return await script._rpc_request_async("call", js_name, args, **kwargs)

return method

def __dir__(self) -> List[str]:
return self._script.list_exports_sync()


class ScriptErrorMessage(TypedDict):
Expand All @@ -198,22 +229,47 @@ class ScriptPayloadMessage(TypedDict):
ScriptDestroyedCallback = Callable[[], None]


class RPCException(Exception):
"""
Wraps remote errors from the script RPC
"""

def __str__(self) -> str:
return str(self.args[2]) if len(self.args) >= 3 else str(self.args[0])


class Script:
def __init__(self, impl: _frida.Script) -> None:
self.exports = ScriptExports(self)
self.exports_sync = ScriptExportsSync(self)
self.exports_async = ScriptExportsAsync(self)

self._impl = impl

self._on_message_callbacks: List[ScriptMessageCallback] = []
self._log_handler: Callable[[str, str], None] = self.default_log_handler

self._pending: Dict[int, Callable[..., Any]] = {}
self._pending: Dict[
int, Callable[[Optional[Any], Optional[Union[RPCException, _frida.InvalidOperationError]]], None]
] = {}
self._next_request_id = 1
self._cond = threading.Condition()

impl.on("destroyed", self._on_destroyed)
impl.on("message", self._on_message)

@property
def exports(self) -> ScriptExportsSync:
"""
The old way of retrieving the synchronous exports caller
"""

warnings.warn(
"Script.exports will become asynchronous in the future, use the explicit Script.exports_sync instead",
DeprecationWarning,
stacklevel=2,
)
return self.exports_sync

def __repr__(self) -> str:
return repr(self._impl)

Expand Down Expand Up @@ -349,7 +405,16 @@ def default_log_handler(self, level: str, text: str) -> None:
else:
print(text, file=sys.stderr)

def list_exports(self) -> List[str]:
async def list_exports_async(self) -> List[str]:
"""
Asynchronously list all the exported attributes from the script's rpc
"""

result = await self._rpc_request_async("list")
assert isinstance(result, list)
return result

def list_exports_sync(self) -> List[str]:
"""
List all the exported attributes from the script's rpc
"""
Expand All @@ -358,11 +423,42 @@ def list_exports(self) -> List[str]:
assert isinstance(result, list)
return result

def list_exports(self) -> List[str]:
"""
List all the exported attributes from the script's rpc
"""

warnings.warn(
"Script.list_exports will become asynchronous in the future, use the explicit Script.list_exports_sync instead",
DeprecationWarning,
stacklevel=2,
)
return self.list_exports_sync()

def _rpc_request_async(self, *args: Any) -> asyncio.Future[Any]:
loop = asyncio.get_event_loop()
future: asyncio.Future[Any] = asyncio.Future()

def on_complete(value: Any, error: Optional[Union[RPCException, _frida.InvalidOperationError]]) -> None:
if error is not None:
loop.call_soon_threadsafe(future.set_exception, error)
else:
loop.call_soon_threadsafe(future.set_result, value)

request_id = self._append_pending(on_complete)

if not self.is_destroyed:
self._send_rpc_call(request_id, *args)
else:
self._on_destroyed()

return future

@cancellable
def _rpc_request(self, *args: Any) -> Any:
result = RPCResult()

def on_complete(value: Any, error: Union[None, Union[RPCException, _frida.InvalidOperationError]]) -> None:
def on_complete(value: Any, error: Optional[Union[RPCException, _frida.InvalidOperationError]]) -> None:
with self._cond:
result.finished = True
result.value = value
Expand All @@ -373,15 +469,10 @@ def on_cancelled() -> None:
self._pending.pop(request_id, None)
on_complete(None, None)

with self._cond:
request_id = self._next_request_id
self._next_request_id += 1
self._pending[request_id] = on_complete
request_id = self._append_pending(on_complete)

if not self.is_destroyed:
message = ["frida:rpc", request_id]
message.extend(args)
self.post(message)
self._send_rpc_call(request_id, *args)

cancellable = Cancellable.get_current()
cancel_handler = cancellable.connect(on_cancelled)
Expand All @@ -401,7 +492,21 @@ def on_cancelled() -> None:

return result.value

def _on_rpc_message(self, request_id: int, operation: str, params, data) -> None:
def _append_pending(
self, callback: Callable[[Any, Optional[Union[RPCException, _frida.InvalidOperationError]]], None]
) -> int:
with self._cond:
request_id = self._next_request_id
self._next_request_id += 1
self._pending[request_id] = callback
return request_id

def _send_rpc_call(self, request_id: int, *args: Any) -> None:
message = ["frida:rpc", request_id]
message.extend(args)
self.post(message)

def _on_rpc_message(self, request_id: int, operation: str, params: List[Any], data: Optional[Any]) -> None:
if operation in ("ok", "error"):
callback = self._pending.pop(request_id, None)
if callback is None:
Expand Down Expand Up @@ -1172,15 +1277,6 @@ def off(self, signal: str, callback: Callable[..., Any]) -> None:
self._impl.off(signal, callback)


class RPCException(Exception):
"""
Wraps remote errors from the script RPC
"""

def __str__(self) -> str:
return str(self.args[2]) if len(self.args) >= 3 else str(self.args[0])


class EndpointParameters:
def __init__(
self,
Expand Down

0 comments on commit 1d35663

Please sign in to comment.