From 5f4af8dbaf793725c89fe2a19e7f7cbbf751cbf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9phine=20Wolf=20Oberholtzer?= Date: Sun, 14 Jul 2024 21:27:22 -0400 Subject: [PATCH] Implement server lifecycle events --- Makefile | 2 +- pyproject.toml | 1 + supriya/contexts/__init__.py | 3 +- supriya/contexts/core.py | 38 +- supriya/contexts/nonrealtime.py | 9 +- supriya/contexts/realtime.py | 522 ++++++--- supriya/contexts/shm.cpp | 1318 ++++++++++++----------- supriya/osc.py | 1133 ------------------- supriya/osc/__init__.py | 32 + supriya/osc/asynchronous.py | 193 ++++ supriya/osc/messages.py | 505 +++++++++ supriya/osc/protocols.py | 383 +++++++ supriya/osc/threaded.py | 234 ++++ supriya/scsynth.py | 395 ++++--- supriya/typing.py | 5 + supriya/ugens/triggers.py | 2 +- tests/book/conftest.py | 3 +- tests/contexts/test_Server_lifecycle.py | 616 +++++++++-- tests/patterns/test_Pattern.py | 8 +- tests/test_osc.py | 133 ++- 20 files changed, 3322 insertions(+), 2213 deletions(-) delete mode 100644 supriya/osc.py create mode 100644 supriya/osc/__init__.py create mode 100644 supriya/osc/asynchronous.py create mode 100644 supriya/osc/messages.py create mode 100644 supriya/osc/protocols.py create mode 100644 supriya/osc/threaded.py diff --git a/Makefile b/Makefile index 17b838119..876ee7530 100644 --- a/Makefile +++ b/Makefile @@ -63,7 +63,7 @@ isort: ## Reformat via isort lint: reformat flake8 mypy ## Run all linters mypy: ## Type-check via mypy - mypy ${project}/ + mypy ${project}/ tests/ mypy-cov: ## Type-check via mypy with coverage reported to ./mypycov/ mypy --html-report ./mypycov/ ${project}/ diff --git a/pyproject.toml b/pyproject.toml index 62c8ddf01..b23b9fe54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ ] dependencies = [ "platformdirs >= 4.0.0", + "psutil", "uqbar >= 0.7.3", ] description = "A Python API for SuperCollider" diff --git a/supriya/contexts/__init__.py b/supriya/contexts/__init__.py index 2ece96390..c2d3ad723 100644 --- a/supriya/contexts/__init__.py +++ b/supriya/contexts/__init__.py @@ -2,7 +2,7 @@ Tools for interacting with scsynth-compatible execution contexts. """ -from .core import Context +from .core import BootStatus, Context from .entities import ( Buffer, BufferGroup, @@ -19,6 +19,7 @@ __all__ = [ "AsyncServer", "BaseServer", + "BootStatus", "Buffer", "BufferGroup", "Bus", diff --git a/supriya/contexts/core.py b/supriya/contexts/core.py index e288338f4..c914599bc 100644 --- a/supriya/contexts/core.py +++ b/supriya/contexts/core.py @@ -7,6 +7,7 @@ import abc import contextlib import dataclasses +import enum import itertools import threading from os import PathLike @@ -101,6 +102,13 @@ ) +class BootStatus(enum.IntEnum): + OFFLINE = 0 + BOOTING = 1 + ONLINE = 2 + QUITTING = 3 + + @dataclasses.dataclass class Moment: """ @@ -204,13 +212,20 @@ class Context(metaclass=abc.ABCMeta): ### INITIALIZER ### - def __init__(self, options: Optional[Options], **kwargs) -> None: + def __init__( + self, + options: Optional[Options], + name: Optional[str] = None, + **kwargs, + ) -> None: self._audio_bus_allocator = BlockAllocator() + self._boot_status = BootStatus.OFFLINE self._buffer_allocator = BlockAllocator() self._client_id = 0 self._control_bus_allocator = BlockAllocator() self._latency = 0.0 self._lock = threading.RLock() + self._name = name self._node_id_allocator = NodeIdAllocator() self._options = new(options or Options(), **kwargs) self._sync_id = self._sync_id_minimum = 0 @@ -1380,6 +1395,20 @@ def audio_output_bus_group(self) -> BusGroup: count=self.options.output_bus_channel_count, ) + @property + def boot_status(self) -> BootStatus: + """ + Get the server's boot status. + """ + return self._boot_status + + @property + def default_group(self) -> Group: + """ + Get the server's default group. + """ + return self.root_node + @property def client_id(self) -> int: """ @@ -1394,6 +1423,13 @@ def latency(self) -> float: """ return self._latency + @property + def name(self) -> Optional[str]: + """ + Get the context's optional name. + """ + return self._name + @property def options(self) -> Options: """ diff --git a/supriya/contexts/nonrealtime.py b/supriya/contexts/nonrealtime.py index 340a70f5d..dd663b108 100644 --- a/supriya/contexts/nonrealtime.py +++ b/supriya/contexts/nonrealtime.py @@ -2,7 +2,6 @@ Tools for interacting with non-realtime execution contexts. """ -import asyncio import hashlib import logging import platform @@ -22,7 +21,7 @@ from ..scsynth import AsyncNonrealtimeProcessProtocol, Options from ..typing import HeaderFormatLike, SampleFormatLike, SupportsOsc from ..ugens import SynthDef -from .core import Context, ContextError, ContextObject, Node +from .core import BootStatus, Context, ContextError, ContextObject, Node from .requests import DoNothing, RequestBundle, Requestable logger = logging.getLogger(__name__) @@ -40,6 +39,7 @@ class Score(Context): def __init__(self, options: Optional[Options] = None, **kwargs): super().__init__(options=options, **kwargs) + self._boot_status = BootStatus.ONLINE self._requests: Dict[float, List[Requestable]] = {} ### CLASS METHODS ### @@ -179,10 +179,9 @@ async def render( ] ) # render the datagram - exit_future = asyncio.get_running_loop().create_future() - protocol = AsyncNonrealtimeProcessProtocol(exit_future) + protocol = AsyncNonrealtimeProcessProtocol() await protocol.run(command, render_directory_path_) - exit_code: int = await exit_future + exit_code: int = await protocol.exit_future assert render_directory_path_ / render_file_name if output_file_path_: shutil.copy( diff --git a/supriya/contexts/realtime.py b/supriya/contexts/realtime.py index d62e8c4d7..42870762a 100644 --- a/supriya/contexts/realtime.py +++ b/supriya/contexts/realtime.py @@ -3,13 +3,16 @@ """ import asyncio -import dataclasses +import concurrent.futures import enum import logging import warnings +from collections.abc import Sequence as SequenceABC from typing import ( TYPE_CHECKING, + Callable, Dict, + Iterable, List, Optional, Sequence, @@ -35,6 +38,7 @@ from ..osc import ( AsyncOscProtocol, HealthCheck, + OscBundle, OscMessage, OscProtocol, OscProtocolOffline, @@ -46,9 +50,10 @@ ProcessProtocol, SyncProcessProtocol, ) -from ..typing import SupportsOsc +from ..typing import FutureLike, SupportsOsc from ..ugens import SynthDef from .core import ( + BootStatus, Buffer, Bus, Context, @@ -108,7 +113,6 @@ class FailWarning(Warning): DEFAULT_HEALTHCHECK = HealthCheck( active=False, backoff_factor=1.5, - callback=lambda: None, max_attempts=5, request_pattern=["/status"], response_pattern=["/status.reply"], @@ -116,11 +120,21 @@ class FailWarning(Warning): ) -class BootStatus(enum.IntEnum): - OFFLINE = 0 - BOOTING = 1 - ONLINE = 2 - QUITTING = 3 +class ServerLifecycleEvent(enum.Enum): + BOOTING = enum.auto() + PROCESS_BOOTED = enum.auto() + CONNECTING = enum.auto() + OSC_CONNECTED = enum.auto() + CONNECTED = enum.auto() + BOOTED = enum.auto() + OSC_PANICKED = enum.auto() + PROCESS_PANICKED = enum.auto() + QUITTING = enum.auto() + DISCONNECTING = enum.auto() + OSC_DISCONNECTED = enum.auto() + DISCONNECTED = enum.auto() + PROCESS_QUIT = enum.auto() + QUIT = enum.auto() class BaseServer(Context): @@ -139,24 +153,30 @@ class BaseServer(Context): def __init__( self, + boot_future: FutureLike[bool], + exit_future: FutureLike[bool], options: Optional[Options], osc_protocol: OscProtocol, process_protocol: ProcessProtocol, + name: Optional[str] = None, **kwargs, ) -> None: - super().__init__(options) - self._latency = 0.1 - self._is_owner = False + super().__init__(options, name=name, **kwargs) + self._boot_future = boot_future self._boot_status = BootStatus.OFFLINE self._buffers: Set[int] = set() + self._exit_future = exit_future + self._is_owner = False + self._latency = 0.1 + self._lifecycle_event_callbacks: Dict[ServerLifecycleEvent, List[Callable]] = {} self._maximum_logins = 1 self._node_active: Dict[int, bool] = {} self._node_children: Dict[int, List[int]] = {} self._node_parents: Dict[int, int] = {} self._osc_protocol = osc_protocol self._process_protocol = process_protocol - self._shm: Optional["ServerSHM"] = None self._setup_osc_callbacks() + self._shm: Optional["ServerSHM"] = None self._status: Optional[StatusInfo] = None ### SPECIAL METHODS ### @@ -185,6 +205,20 @@ def __contains__(self, object_: ContextObject) -> bool: ### PRIVATE METHODS ### + def _add_node_to_children( + self, id_: int, parent_id: int, previous_id: int, next_id: int + ) -> None: + self._node_parents[id_] = parent_id + children = self._node_children[parent_id] + if previous_id == -1: + children.insert(0, id_) + elif next_id == -1: + children.append(id_) + elif previous_id in children: + children.insert(children.index(previous_id) + 1, id_) + elif next_id in children: + children.insert(children.index(next_id), id_) + def _free_id( self, type_: Type[ContextObject], @@ -193,92 +227,82 @@ def _free_id( ) -> None: self._get_allocator(type_, calculation_rate).free(id_) - def _handle_osc_callbacks(self, message: OscMessage) -> None: - def _handle_done(message: OscMessage) -> None: - if message.contents[0] in ( - "/b_alloc", - "/b_allocRead", - "/b_allocReadChannel", - ): - self._buffers.add(message.contents[1]) - elif message.contents[0] == "/b_free": - if message.contents[1] in self._buffers: - self._buffers.remove(message.contents[1]) - self._free_id(Buffer, message.contents[1]) + def _handle_done_b_alloc(self, message: OscMessage) -> None: + with self._lock: + self._buffers.add(message.contents[1]) - def _handle_fail(message: OscMessage) -> None: - warnings.warn(" ".join(str(x) for x in message.contents), FailWarning) + def _handle_done_b_alloc_read(self, message: OscMessage) -> None: + with self._lock: + self._buffers.add(message.contents[1]) + + def _handle_done_b_alloc_read_channel(self, message: OscMessage) -> None: + with self._lock: + self._buffers.add(message.contents[1]) - def _handle_n_end(message: OscMessage) -> None: + def _handle_done_b_free(self, message: OscMessage) -> None: + with self._lock: + if message.contents[1] in self._buffers: + self._buffers.remove(message.contents[1]) + self._free_id(Buffer, message.contents[1]) + + def _handle_done_quit(self, message: OscMessage): + raise NotImplementedError + + def _handle_fail(self, message: OscMessage) -> None: + warnings.warn(" ".join(str(x) for x in message.contents), FailWarning) + + def _handle_n_end(self, message: OscMessage) -> None: + with self._lock: id_, parent_id, *_ = message.contents if parent_id == -1: parent_id = self._node_parents.get(id_) if parent_id is not None: - _remove_node_from_children(id_, parent_id) + self._remove_node_from_children(id_, parent_id) self._free_id(Node, id_) self._node_active.pop(id_, None) self._node_children.pop(id_, None) self._node_parents.pop(id_, None) - def _handle_n_go(message: OscMessage) -> None: + def _handle_n_go(self, message: OscMessage) -> None: + with self._lock: id_, parent_id, previous_id, next_id, is_group, *_ = message.contents self._node_parents[id_] = parent_id self._node_active[id_] = True if is_group: self._node_children[id_] = [] - _add_node_to_children(id_, parent_id, previous_id, next_id) + self._add_node_to_children(id_, parent_id, previous_id, next_id) - def _handle_n_move(message: OscMessage) -> None: + def _handle_n_move(self, message: OscMessage) -> None: + with self._lock: id_, parent_id, previous_id, next_id, *_ = message.contents old_parent_id = self._node_parents[id_] - _remove_node_from_children(id_, old_parent_id) - _add_node_to_children(id_, parent_id, previous_id, next_id) + self._remove_node_from_children(id_, old_parent_id) + self._add_node_to_children(id_, parent_id, previous_id, next_id) - def _handle_n_off(message: OscMessage) -> None: + def _handle_n_off(self, message: OscMessage) -> None: + with self._lock: self._node_active[message.contents[0]] = False - def _handle_n_on(message: OscMessage) -> None: + def _handle_n_on(self, message: OscMessage) -> None: + with self._lock: self._node_active[message.contents[0]] = True - def _handle_status_reply(message: OscMessage): + def _handle_status_reply(self, message: OscMessage): + with self._lock: self._status = cast(StatusInfo, StatusInfo.from_osc(message)) - def _add_node_to_children( - id_: int, parent_id: int, previous_id: int, next_id: int - ) -> None: - self._node_parents[id_] = parent_id - children = self._node_children[parent_id] - if previous_id == -1: - children.insert(0, id_) - elif next_id == -1: - children.append(id_) - elif previous_id in children: - children.insert(children.index(previous_id) + 1, id_) - elif next_id in children: - children.insert(children.index(next_id), id_) - - def _remove_node_from_children(id_: int, parent_id: int) -> None: - children = self._node_children[parent_id] - try: - children.pop(children.index(id_)) - except ValueError: - pass - - handlers = { - "/done": _handle_done, - "/fail": _handle_fail, - "/n_end": _handle_n_end, - "/n_go": _handle_n_go, - "/n_move": _handle_n_move, - "/n_off": _handle_n_off, - "/n_on": _handle_n_on, - "/status.reply": _handle_status_reply, - } + def _on_lifecycle_event(self, event: ServerLifecycleEvent) -> None: + for callback in self._lifecycle_event_callbacks.get(event, []): + if asyncio.iscoroutine(result := callback(event)): + asyncio.get_running_loop().create_task(result) - with self._lock: - handler = handlers.get(str(message.address)) - if handler is not None: - handler(message) + def _remove_node_from_children(self, id_: int, parent_id: int) -> None: + if not (children := self._node_children.get(parent_id, [])): + return + try: + children.pop(children.index(id_)) + except ValueError: + pass def _resolve_node(self, node: Union[Node, SupportsInt, None]) -> int: if node is None: @@ -286,19 +310,21 @@ def _resolve_node(self, node: Union[Node, SupportsInt, None]) -> int: return int(node) def _setup_osc_callbacks(self) -> None: - for pattern in ( - ["/done"], - ["/fail"], - ["/n_end"], - ["/n_go"], - ["/n_move"], - ["/n_off"], - ["/n_on"], - ["/status.reply"], - ): - self._osc_protocol.register( - pattern=pattern, procedure=self._handle_osc_callbacks - ) + for pattern, procedure in [ + (["/done", "/b_alloc"], self._handle_done_b_alloc), + (["/done", "/b_allocRead"], self._handle_done_b_alloc_read), + (["/done", "/b_allocReadChannel"], self._handle_done_b_alloc_read_channel), + (["/done", "/b_free"], self._handle_done_b_free), + (["/done", "/quit"], self._handle_done_quit), + (["/fail"], self._handle_fail), + (["/n_end"], self._handle_n_end), + (["/n_go"], self._handle_n_go), + (["/n_move"], self._handle_n_move), + (["/n_off"], self._handle_n_off), + (["/n_on"], self._handle_n_on), + (["/status.reply"], self._handle_status_reply), + ]: + self._osc_protocol.register(pattern=pattern, procedure=procedure) def _setup_shm(self) -> None: try: @@ -340,16 +366,33 @@ def _validate_moment_timestamp(self, seconds: Optional[float]) -> None: ### PUBLIC METHODS ### - def send(self, message: SupportsOsc) -> None: + def on( + self, + event: Union[ServerLifecycleEvent, Iterable[ServerLifecycleEvent]], + callback: Callable[[ServerLifecycleEvent], None], + ) -> None: + if isinstance(event, ServerLifecycleEvent): + events_ = [event] + else: + events_ = list(set(event)) + for event_ in events_: + if callback not in ( + callbacks := self._lifecycle_event_callbacks.setdefault(event_, []) + ): + callbacks.append(callback) + + def send( + self, message: Union[OscMessage, OscBundle, SupportsOsc, SequenceABC, str] + ) -> None: """ Send a message to the execution context. :param message: The message to send. """ - if self._boot_status not in (BootStatus.BOOTING, BootStatus.ONLINE): + if self._boot_status == BootStatus.OFFLINE: raise ServerOffline self._osc_protocol.send( - message.to_osc() if hasattr(message, "to_osc") else message + message.to_osc() if isinstance(message, SupportsOsc) else message ) def set_latency(self, latency: float) -> None: @@ -363,11 +406,14 @@ def set_latency(self, latency: float) -> None: ### PUBLIC PROPERTIES ### @property - def boot_status(self) -> BootStatus: + def boot_future(self) -> FutureLike[bool]: """ - Get the server's boot status. + Get the server's boot future. + + Only reference this _after_ booting or connecting, as the future is + created when booting or connecting. """ - return self._boot_status + return self._boot_future @property def default_group(self) -> Group: @@ -376,6 +422,16 @@ def default_group(self) -> Group: """ return Group(context=self, id_=self._client_id + 1) + @property + def exit_future(self) -> FutureLike[bool]: + """ + Get the server's exit future. + + Only reference this _after_ booting or connecting, as the future is + created when booting or connecting. + """ + return self._exit_future + @property def is_owner(self) -> bool: """ @@ -415,24 +471,54 @@ class Server(BaseServer): ### INITIALIZER ### - def __init__(self, options: Optional[Options] = None, **kwargs): + def __init__( + self, + options: Optional[Options] = None, + name: Optional[str] = None, + **kwargs, + ): super().__init__( - osc_protocol=ThreadedOscProtocol(), + boot_future=concurrent.futures.Future(), + exit_future=concurrent.futures.Future(), + name=name, options=options, - process_protocol=SyncProcessProtocol(), + osc_protocol=ThreadedOscProtocol( + name=name, + on_connect_callback=lambda: self._on_lifecycle_event( + ServerLifecycleEvent.OSC_CONNECTED, + ), + on_disconnect_callback=lambda: self._on_lifecycle_event( + ServerLifecycleEvent.OSC_DISCONNECTED, + ), + on_panic_callback=self._on_osc_panicked, + ), + process_protocol=SyncProcessProtocol( + name=name, + on_boot_callback=lambda: self._on_lifecycle_event( + ServerLifecycleEvent.PROCESS_BOOTED + ), + on_panic_callback=lambda: self._on_lifecycle_event( + ServerLifecycleEvent.PROCESS_PANICKED + ), + on_quit_callback=lambda: self._on_lifecycle_event( + ServerLifecycleEvent.PROCESS_QUIT + ), + ), **kwargs, ) ### PRIVATE METHODS ### def _connect(self) -> None: - logger.info("Connecting") + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "connecting ..." + ) + self._on_lifecycle_event(ServerLifecycleEvent.CONNECTING) cast(ThreadedOscProtocol, self._osc_protocol).connect( ip_address=self._options.ip_address, port=self._options.port, - healthcheck=dataclasses.replace( - DEFAULT_HEALTHCHECK, callback=self._shutdown - ), + healthcheck=DEFAULT_HEALTHCHECK, ) self._setup_notifications() self._contexts.add(self) @@ -441,24 +527,63 @@ def _connect(self) -> None: if self._client_id == 0: self._setup_system() self.sync() + self._osc_protocol.boot_future.result() + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "... connected!" + ) self._boot_status = BootStatus.ONLINE - logger.info("Connected") + if not self.is_owner: + self._boot_future.set_result(True) + self._on_lifecycle_event(ServerLifecycleEvent.CONNECTED) def _disconnect(self) -> None: - logger.info("Disconnecting") - self._boot_status = BootStatus.QUITTING - self._teardown_shm() + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "disconnecting ..." + ) + self._on_lifecycle_event(ServerLifecycleEvent.DISCONNECTING) self._osc_protocol.disconnect() self._teardown_shm() self._teardown_state() if self in self._contexts: self._contexts.remove(self) + was_owner = self._is_owner = True self._is_owner = False + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "... disconnected!" + ) self._boot_status = BootStatus.OFFLINE - logger.info("Disconnected") + self._on_lifecycle_event(ServerLifecycleEvent.DISCONNECTED) + if not was_owner: + if not self._boot_future.done(): + self._boot_future.set_result(True) + if not self._exit_future.done(): + self._exit_future.set_result(True) + + def _handle_done_quit(self, message: OscMessage) -> None: + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + f"handling {message.to_list()} ..." + ) + if self._boot_status == BootStatus.ONLINE: + self._shutdown() + else: + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + f"... already quitting!" + ) + + def _on_osc_panicked(self) -> None: + self._on_lifecycle_event(ServerLifecycleEvent.OSC_PANICKED) + self._shutdown() def _setup_notifications(self) -> None: - logger.info("Setting up notifications") + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "setting up notifications ..." + ) response = ToggleNotifications(True).communicate(server=self) if response is None or not isinstance(response, (DoneInfo, FailInfo)): raise RuntimeError @@ -489,17 +614,37 @@ def boot(self, *, options: Optional[Options] = None, **kwargs) -> "Server": """ if self._boot_status != BootStatus.OFFLINE: raise ServerOnline + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "booting ..." + ) + self._boot_future = concurrent.futures.Future() + self._exit_future = concurrent.futures.Future() self._boot_status = BootStatus.BOOTING + self._on_lifecycle_event(ServerLifecycleEvent.BOOTING) self._options = new(options or self._options, **kwargs) - logger.debug(f"Options: {self._options}") + logger.debug( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + f"options: {self._options}" + ) try: - self._process_protocol.boot(self._options) + cast(SyncProcessProtocol, self._process_protocol).boot(self._options) except ServerCannotBoot: + if not self._boot_future.done(): + self._boot_future.set_result(False) + self._exit_future.set_result(False) self._boot_status = BootStatus.OFFLINE raise self._is_owner = True - self._connect() self._setup_shm() + self._connect() + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "... booted!" + ) + if self.is_owner: + self._boot_future.set_result(True) + self._on_lifecycle_event(ServerLifecycleEvent.BOOTED) return self def connect(self, *, options: Optional[Options] = None, **kwargs) -> "Server": @@ -509,8 +654,10 @@ def connect(self, *, options: Optional[Options] = None, **kwargs) -> "Server": :param options: The context's options. :param kwargs: Keyword arguments for options. """ - if self._boot_status in (BootStatus.BOOTING, BootStatus.ONLINE): + if self._boot_status != BootStatus.OFFLINE: raise ServerOnline + self._boot_future = concurrent.futures.Future() + self._exit_future = concurrent.futures.Future() self._boot_status = BootStatus.BOOTING self._options = new(options or self._options, **kwargs) self._is_owner = False @@ -525,6 +672,7 @@ def disconnect(self) -> "Server": raise ServerOffline if self._is_owner: raise OwnedServerShutdown("Cannot disconnect from owned server.") + self._boot_status = BootStatus.QUITTING self._disconnect() return self @@ -756,13 +904,27 @@ def quit(self, force: bool = False) -> "Server": raise UnownedServerShutdown( "Cannot quit unowned server without force flag." ) + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "quitting ..." + ) + self._boot_status = BootStatus.QUITTING + self._on_lifecycle_event(ServerLifecycleEvent.QUITTING) try: Quit().communicate(server=self) except OscProtocolOffline: pass - self._teardown_shm() - self._process_protocol.quit() self._disconnect() + cast(SyncProcessProtocol, self._process_protocol).quit() + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "... quit!" + ) + self._on_lifecycle_event(ServerLifecycleEvent.QUIT) + if not self._boot_future.done(): + self._boot_future.set_result(True) + if not self._exit_future.done(): + self._exit_future.set_result(True) return self def reboot(self) -> "Server": @@ -814,24 +976,51 @@ class AsyncServer(BaseServer): ### INITIALIZER ### - def __init__(self, options: Optional[Options] = None, **kwargs): + def __init__( + self, options: Optional[Options] = None, name: Optional[str] = None, **kwargs + ): super().__init__( - osc_protocol=AsyncOscProtocol(), + boot_future=asyncio.Future(), + exit_future=asyncio.Future(), + name=name, options=options, - process_protocol=AsyncProcessProtocol(), + osc_protocol=AsyncOscProtocol( + name=name, + on_connect_callback=lambda: self._on_lifecycle_event( + ServerLifecycleEvent.OSC_CONNECTED, + ), + on_disconnect_callback=lambda: self._on_lifecycle_event( + ServerLifecycleEvent.OSC_DISCONNECTED, + ), + on_panic_callback=self._on_osc_panicked, + ), + process_protocol=AsyncProcessProtocol( + name=name, + on_boot_callback=lambda: self._on_lifecycle_event( + ServerLifecycleEvent.PROCESS_BOOTED + ), + on_panic_callback=lambda: self._on_lifecycle_event( + ServerLifecycleEvent.PROCESS_PANICKED + ), + on_quit_callback=lambda: self._on_lifecycle_event( + ServerLifecycleEvent.PROCESS_QUIT + ), + ), **kwargs, ) ### PRIVATE METHODS ### async def _connect(self) -> None: - logger.info("Connecting") + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "connecting ..." + ) + self._on_lifecycle_event(ServerLifecycleEvent.CONNECTING) await cast(AsyncOscProtocol, self._osc_protocol).connect( ip_address=self._options.ip_address, port=self._options.port, - healthcheck=dataclasses.replace( - DEFAULT_HEALTHCHECK, callback=self._shutdown - ), + healthcheck=DEFAULT_HEALTHCHECK, ) await self._setup_notifications() self._contexts.add(self) @@ -840,23 +1029,63 @@ async def _connect(self) -> None: if self._client_id == 0: self._setup_system() await self.sync() + await cast(asyncio.Future, self._osc_protocol.boot_future) + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "... connected!" + ) self._boot_status = BootStatus.ONLINE - logger.info("Connected") + if not self.is_owner: + self._boot_future.set_result(True) + self._on_lifecycle_event(ServerLifecycleEvent.CONNECTED) async def _disconnect(self) -> None: - logger.info("Disconnecting") - self._boot_status = BootStatus.QUITTING - self._osc_protocol.disconnect() + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "disconnecting ..." + ) + self._on_lifecycle_event(ServerLifecycleEvent.DISCONNECTING) + await cast(AsyncOscProtocol, self._osc_protocol).disconnect() self._teardown_shm() self._teardown_state() if self in self._contexts: self._contexts.remove(self) + was_owner = self._is_owner = True self._is_owner = False + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "... disconnected!" + ) self._boot_status = BootStatus.OFFLINE - logger.info("Disconnected") + self._on_lifecycle_event(ServerLifecycleEvent.DISCONNECTED) + if not was_owner: + if not self._boot_future.done(): + self._boot_future.set_result(True) + if not self._exit_future.done(): + self._exit_future.set_result(True) + + async def _handle_done_quit(self, message: OscMessage) -> None: + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + f"handling {message.to_list()} ..." + ) + if self._boot_status == BootStatus.ONLINE: + await self._shutdown() + else: + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + f"... already quitting!" + ) + + async def _on_osc_panicked(self) -> None: + self._on_lifecycle_event(ServerLifecycleEvent.OSC_PANICKED) + await self._shutdown() async def _setup_notifications(self) -> None: - logger.info("Setting up notifications") + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "setting up notifications ..." + ) response = await ToggleNotifications(True).communicate_async(server=self) if response is None or not isinstance(response, (DoneInfo, FailInfo)): raise RuntimeError @@ -889,16 +1118,38 @@ async def boot( """ if self._boot_status != BootStatus.OFFLINE: raise ServerOnline + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "booting ..." + ) + loop = asyncio.get_running_loop() + self._boot_future = loop.create_future() + self._exit_future = loop.create_future() self._boot_status = BootStatus.BOOTING + self._on_lifecycle_event(ServerLifecycleEvent.BOOTING) self._options = new(options or self._options, **kwargs) - logger.debug(f"Options: {self._options}") + logger.debug( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + f"options: {self._options}" + ) try: - await self._process_protocol.boot(self._options) + await cast(AsyncProcessProtocol, self._process_protocol).boot(self._options) except ServerCannotBoot: + if not self._boot_future.done(): + self._boot_future.set_result(False) + self._exit_future.set_result(False) self._boot_status = BootStatus.OFFLINE raise self._is_owner = True + self._setup_shm() await self._connect() + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "... booted!" + ) + self._on_lifecycle_event(ServerLifecycleEvent.BOOTED) + if self.is_owner: + self._boot_future.set_result(True) return self async def connect( @@ -910,8 +1161,11 @@ async def connect( :param options: The context's options. :param kwargs: Keyword arguments for options. """ - if self._boot_status in (BootStatus.BOOTING, BootStatus.ONLINE): + if self._boot_status != BootStatus.OFFLINE: raise ServerOnline + loop = asyncio.get_running_loop() + self._boot_future = loop.create_future() + self._exit_future = loop.create_future() self._boot_status = BootStatus.BOOTING self._options = new(options or self._options, **kwargs) self._is_owner = False @@ -926,6 +1180,7 @@ async def disconnect(self) -> "AsyncServer": raise ServerOffline if self._is_owner: raise OwnedServerShutdown("Cannot disconnect from owned server.") + self._boot_status = BootStatus.QUITTING await self._disconnect() return self @@ -1163,12 +1418,27 @@ async def quit(self, force: bool = False) -> "AsyncServer": raise UnownedServerShutdown( "Cannot quit unowned server without force flag." ) + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "quitting ..." + ) + self._boot_status = BootStatus.QUITTING + self._on_lifecycle_event(ServerLifecycleEvent.QUITTING) try: await Quit().communicate_async(server=self, timeout=1) except (OscProtocolOffline, asyncio.TimeoutError): pass - await self._process_protocol.quit() await self._disconnect() + await cast(AsyncProcessProtocol, self._process_protocol).quit() + logger.info( + f"[{self._options.ip_address}:{self._options.port}/{self.name or hex(id(self))}] " + "... quit!" + ) + self._on_lifecycle_event(ServerLifecycleEvent.QUIT) + if not self._boot_future.done(): + self._boot_future.set_result(True) + if not self._exit_future.done(): + self._exit_future.set_result(True) return self async def reboot(self) -> "AsyncServer": diff --git a/supriya/contexts/shm.cpp b/supriya/contexts/shm.cpp index e19d18c71..26c65daf0 100644 --- a/supriya/contexts/shm.cpp +++ b/supriya/contexts/shm.cpp @@ -1,4 +1,4 @@ -/* Generated by Cython 3.0.9 */ +/* Generated by Cython 3.0.11 */ /* BEGIN: Cython Metadata { @@ -45,10 +45,10 @@ END: Cython Metadata */ #else #define __PYX_EXTRA_ABI_MODULE_NAME "" #endif -#define CYTHON_ABI "3_0_9" __PYX_EXTRA_ABI_MODULE_NAME +#define CYTHON_ABI "3_0_11" __PYX_EXTRA_ABI_MODULE_NAME #define __PYX_ABI_MODULE_NAME "_cython_" CYTHON_ABI #define __PYX_TYPE_MODULE_PREFIX __PYX_ABI_MODULE_NAME "." -#define CYTHON_HEX_VERSION 0x030009F0 +#define CYTHON_HEX_VERSION 0x03000BF0 #define CYTHON_FUTURE_DIVISION 1 #include #ifndef offsetof @@ -140,6 +140,8 @@ END: Cython Metadata */ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 #endif + #undef CYTHON_USE_FREELISTS + #define CYTHON_USE_FREELISTS 0 #elif defined(PYPY_VERSION) #define CYTHON_COMPILING_IN_PYPY 1 #define CYTHON_COMPILING_IN_CPYTHON 0 @@ -201,6 +203,8 @@ END: Cython Metadata */ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 #endif + #undef CYTHON_USE_FREELISTS + #define CYTHON_USE_FREELISTS 0 #elif defined(CYTHON_LIMITED_API) #ifdef Py_LIMITED_API #undef __PYX_LIMITED_VERSION_HEX @@ -262,6 +266,8 @@ END: Cython Metadata */ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 #endif + #undef CYTHON_USE_FREELISTS + #define CYTHON_USE_FREELISTS 0 #elif defined(Py_GIL_DISABLED) || defined(Py_NOGIL) #define CYTHON_COMPILING_IN_PYPY 0 #define CYTHON_COMPILING_IN_CPYTHON 0 @@ -271,11 +277,17 @@ END: Cython Metadata */ #ifndef CYTHON_USE_TYPE_SLOTS #define CYTHON_USE_TYPE_SLOTS 1 #endif + #ifndef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #endif #undef CYTHON_USE_PYTYPE_LOOKUP #define CYTHON_USE_PYTYPE_LOOKUP 0 #ifndef CYTHON_USE_ASYNC_SLOTS #define CYTHON_USE_ASYNC_SLOTS 1 #endif + #ifndef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #endif #undef CYTHON_USE_PYLIST_INTERNALS #define CYTHON_USE_PYLIST_INTERNALS 0 #ifndef CYTHON_USE_UNICODE_INTERNALS @@ -283,8 +295,6 @@ END: Cython Metadata */ #endif #undef CYTHON_USE_UNICODE_WRITER #define CYTHON_USE_UNICODE_WRITER 0 - #undef CYTHON_USE_PYLONG_INTERNALS - #define CYTHON_USE_PYLONG_INTERNALS 0 #ifndef CYTHON_AVOID_BORROWED_REFS #define CYTHON_AVOID_BORROWED_REFS 0 #endif @@ -296,11 +306,22 @@ END: Cython Metadata */ #endif #undef CYTHON_FAST_THREAD_STATE #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #ifndef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 1 + #endif #undef CYTHON_FAST_PYCALL #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS 1 + #endif #ifndef CYTHON_PEP489_MULTI_PHASE_INIT #define CYTHON_PEP489_MULTI_PHASE_INIT 1 #endif + #ifndef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #endif #ifndef CYTHON_USE_TP_FINALIZE #define CYTHON_USE_TP_FINALIZE 1 #endif @@ -308,6 +329,12 @@ END: Cython Metadata */ #define CYTHON_USE_DICT_VERSIONS 0 #undef CYTHON_USE_EXC_INFO_STACK #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 1 + #endif + #ifndef CYTHON_USE_FREELISTS + #define CYTHON_USE_FREELISTS 0 + #endif #else #define CYTHON_COMPILING_IN_PYPY 0 #define CYTHON_COMPILING_IN_CPYTHON 1 @@ -398,6 +425,9 @@ END: Cython Metadata */ #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC #define CYTHON_UPDATE_DESCRIPTOR_DOC 1 #endif + #ifndef CYTHON_USE_FREELISTS + #define CYTHON_USE_FREELISTS 1 + #endif #endif #if !defined(CYTHON_FAST_PYCCALL) #define CYTHON_FAST_PYCCALL (CYTHON_FAST_PYCALL && PY_VERSION_HEX >= 0x030600B1) @@ -2711,7 +2741,7 @@ static PyObject *__pyx_pf_7supriya_8contexts_3shm_9ServerSHM_4__getitem__(struct Py_ssize_t __pyx_t_5; PyObject *__pyx_t_6 = NULL; PyObject *__pyx_t_7 = NULL; - int __pyx_t_8; + unsigned int __pyx_t_8; PyObject *(*__pyx_t_9)(PyObject *); Py_ssize_t __pyx_t_10; int __pyx_t_11; @@ -3383,6 +3413,9 @@ static PyTypeObject __pyx_type_7supriya_8contexts_3shm_ServerSHM = { #if PY_VERSION_HEX >= 0x030C0000 0, /*tp_watched*/ #endif + #if PY_VERSION_HEX >= 0x030d00A4 + 0, /*tp_versions_used*/ + #endif #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 0, /*tp_pypy_flags*/ #endif @@ -6698,6 +6731,9 @@ static PyTypeObject __pyx_CyFunctionType_type = { #if PY_VERSION_HEX >= 0x030C0000 0, #endif +#if PY_VERSION_HEX >= 0x030d00A4 + 0, +#endif #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 0, #endif @@ -7141,245 +7177,239 @@ static CYTHON_INLINE unsigned int __Pyx_PyInt_As_unsigned_int(PyObject *x) { } return (unsigned int) val; } - } else + } #endif - if (likely(PyLong_Check(x))) { - if (is_unsigned) { + if (unlikely(!PyLong_Check(x))) { + unsigned int val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (unsigned int) -1; + val = __Pyx_PyInt_As_unsigned_int(tmp); + Py_DECREF(tmp); + return val; + } + if (is_unsigned) { #if CYTHON_USE_PYLONG_INTERNALS - if (unlikely(__Pyx_PyLong_IsNeg(x))) { - goto raise_neg_overflow; - } else if (__Pyx_PyLong_IsCompact(x)) { - __PYX_VERIFY_RETURN_INT(unsigned int, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) - } else { - const digit* digits = __Pyx_PyLong_Digits(x); - assert(__Pyx_PyLong_DigitCount(x) > 1); - switch (__Pyx_PyLong_DigitCount(x)) { - case 2: - if ((8 * sizeof(unsigned int) > 1 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(unsigned int) >= 2 * PyLong_SHIFT)) { - return (unsigned int) (((((unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])); - } + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(unsigned int, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(unsigned int) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(unsigned int) >= 2 * PyLong_SHIFT)) { + return (unsigned int) (((((unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])); } - break; - case 3: - if ((8 * sizeof(unsigned int) > 2 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(unsigned int) >= 3 * PyLong_SHIFT)) { - return (unsigned int) (((((((unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])); - } + } + break; + case 3: + if ((8 * sizeof(unsigned int) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(unsigned int) >= 3 * PyLong_SHIFT)) { + return (unsigned int) (((((((unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])); } - break; - case 4: - if ((8 * sizeof(unsigned int) > 3 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(unsigned int) >= 4 * PyLong_SHIFT)) { - return (unsigned int) (((((((((unsigned int)digits[3]) << PyLong_SHIFT) | (unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])); - } + } + break; + case 4: + if ((8 * sizeof(unsigned int) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(unsigned int) >= 4 * PyLong_SHIFT)) { + return (unsigned int) (((((((((unsigned int)digits[3]) << PyLong_SHIFT) | (unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0])); } - break; - } + } + break; } + } #endif #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 - if (unlikely(Py_SIZE(x) < 0)) { - goto raise_neg_overflow; - } + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } #else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (unsigned int) -1; - if (unlikely(result == 1)) - goto raise_neg_overflow; - } + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (unsigned int) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } #endif - if ((sizeof(unsigned int) <= sizeof(unsigned long))) { - __PYX_VERIFY_RETURN_INT_EXC(unsigned int, unsigned long, PyLong_AsUnsignedLong(x)) + if ((sizeof(unsigned int) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(unsigned int, unsigned long, PyLong_AsUnsignedLong(x)) #ifdef HAVE_LONG_LONG - } else if ((sizeof(unsigned int) <= sizeof(unsigned PY_LONG_LONG))) { - __PYX_VERIFY_RETURN_INT_EXC(unsigned int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) + } else if ((sizeof(unsigned int) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(unsigned int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) #endif - } - } else { + } + } else { #if CYTHON_USE_PYLONG_INTERNALS - if (__Pyx_PyLong_IsCompact(x)) { - __PYX_VERIFY_RETURN_INT(unsigned int, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) - } else { - const digit* digits = __Pyx_PyLong_Digits(x); - assert(__Pyx_PyLong_DigitCount(x) > 1); - switch (__Pyx_PyLong_SignedDigitCount(x)) { - case -2: - if ((8 * sizeof(unsigned int) - 1 > 1 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(unsigned int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(unsigned int) - 1 > 2 * PyLong_SHIFT)) { - return (unsigned int) (((unsigned int)-1)*(((((unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); - } + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(unsigned int, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(unsigned int) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(unsigned int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(unsigned int) - 1 > 2 * PyLong_SHIFT)) { + return (unsigned int) (((unsigned int)-1)*(((((unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); } - break; - case 2: - if ((8 * sizeof(unsigned int) > 1 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(unsigned int) - 1 > 2 * PyLong_SHIFT)) { - return (unsigned int) ((((((unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); - } + } + break; + case 2: + if ((8 * sizeof(unsigned int) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(unsigned int) - 1 > 2 * PyLong_SHIFT)) { + return (unsigned int) ((((((unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); } - break; - case -3: - if ((8 * sizeof(unsigned int) - 1 > 2 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(unsigned int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(unsigned int) - 1 > 3 * PyLong_SHIFT)) { - return (unsigned int) (((unsigned int)-1)*(((((((unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); - } + } + break; + case -3: + if ((8 * sizeof(unsigned int) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(unsigned int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(unsigned int) - 1 > 3 * PyLong_SHIFT)) { + return (unsigned int) (((unsigned int)-1)*(((((((unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); } - break; - case 3: - if ((8 * sizeof(unsigned int) > 2 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(unsigned int) - 1 > 3 * PyLong_SHIFT)) { - return (unsigned int) ((((((((unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); - } + } + break; + case 3: + if ((8 * sizeof(unsigned int) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(unsigned int) - 1 > 3 * PyLong_SHIFT)) { + return (unsigned int) ((((((((unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); } - break; - case -4: - if ((8 * sizeof(unsigned int) - 1 > 3 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(unsigned int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(unsigned int) - 1 > 4 * PyLong_SHIFT)) { - return (unsigned int) (((unsigned int)-1)*(((((((((unsigned int)digits[3]) << PyLong_SHIFT) | (unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); - } + } + break; + case -4: + if ((8 * sizeof(unsigned int) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(unsigned int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(unsigned int) - 1 > 4 * PyLong_SHIFT)) { + return (unsigned int) (((unsigned int)-1)*(((((((((unsigned int)digits[3]) << PyLong_SHIFT) | (unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); } - break; - case 4: - if ((8 * sizeof(unsigned int) > 3 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(unsigned int) - 1 > 4 * PyLong_SHIFT)) { - return (unsigned int) ((((((((((unsigned int)digits[3]) << PyLong_SHIFT) | (unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); - } + } + break; + case 4: + if ((8 * sizeof(unsigned int) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(unsigned int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(unsigned int) - 1 > 4 * PyLong_SHIFT)) { + return (unsigned int) ((((((((((unsigned int)digits[3]) << PyLong_SHIFT) | (unsigned int)digits[2]) << PyLong_SHIFT) | (unsigned int)digits[1]) << PyLong_SHIFT) | (unsigned int)digits[0]))); } - break; - } + } + break; } + } #endif - if ((sizeof(unsigned int) <= sizeof(long))) { - __PYX_VERIFY_RETURN_INT_EXC(unsigned int, long, PyLong_AsLong(x)) + if ((sizeof(unsigned int) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(unsigned int, long, PyLong_AsLong(x)) #ifdef HAVE_LONG_LONG - } else if ((sizeof(unsigned int) <= sizeof(PY_LONG_LONG))) { - __PYX_VERIFY_RETURN_INT_EXC(unsigned int, PY_LONG_LONG, PyLong_AsLongLong(x)) + } else if ((sizeof(unsigned int) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(unsigned int, PY_LONG_LONG, PyLong_AsLongLong(x)) #endif - } + } + } + { + unsigned int val; + int ret = -1; +#if PY_VERSION_HEX >= 0x030d00A6 && !CYTHON_COMPILING_IN_LIMITED_API + Py_ssize_t bytes_copied = PyLong_AsNativeBytes( + x, &val, sizeof(val), Py_ASNATIVEBYTES_NATIVE_ENDIAN | (is_unsigned ? Py_ASNATIVEBYTES_UNSIGNED_BUFFER | Py_ASNATIVEBYTES_REJECT_NEGATIVE : 0)); + if (unlikely(bytes_copied == -1)) { + } else if (unlikely(bytes_copied > (Py_ssize_t) sizeof(val))) { + goto raise_overflow; + } else { + ret = 0; + } +#elif PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)x, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *v; + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (likely(PyLong_CheckExact(x))) { + v = __Pyx_NewRef(x); + } else { + v = PyNumber_Long(x); + if (unlikely(!v)) return (unsigned int) -1; + assert(PyLong_CheckExact(v)); } { - unsigned int val; - PyObject *v = __Pyx_PyNumber_IntOrLong(x); -#if PY_MAJOR_VERSION < 3 - if (likely(v) && !PyLong_Check(v)) { - PyObject *tmp = v; - v = PyNumber_Long(tmp); - Py_DECREF(tmp); - } -#endif - if (likely(v)) { - int ret = -1; -#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) - int one = 1; int is_little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&val; - ret = _PyLong_AsByteArray((PyLongObject *)v, - bytes, sizeof(val), - is_little, !is_unsigned); -#else - PyObject *stepval = NULL, *mask = NULL, *shift = NULL; - int bits, remaining_bits, is_negative = 0; - long idigit; - int chunk_size = (sizeof(long) < 8) ? 30 : 62; - if (unlikely(!PyLong_CheckExact(v))) { - PyObject *tmp = v; - v = PyNumber_Long(v); - assert(PyLong_CheckExact(v)); - Py_DECREF(tmp); - if (unlikely(!v)) return (unsigned int) -1; - } -#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 - if (Py_SIZE(x) == 0) - return (unsigned int) 0; - is_negative = Py_SIZE(x) < 0; -#else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (unsigned int) -1; - is_negative = result == 1; - } -#endif - if (is_unsigned && unlikely(is_negative)) { - goto raise_neg_overflow; - } else if (is_negative) { - stepval = PyNumber_Invert(v); - if (unlikely(!stepval)) - return (unsigned int) -1; - } else { - stepval = __Pyx_NewRef(v); - } - val = (unsigned int) 0; - mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; - shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; - for (bits = 0; bits < (int) sizeof(unsigned int) * 8 - chunk_size; bits += chunk_size) { - PyObject *tmp, *digit; - digit = PyNumber_And(stepval, mask); - if (unlikely(!digit)) goto done; - idigit = PyLong_AsLong(digit); - Py_DECREF(digit); - if (unlikely(idigit < 0)) goto done; - tmp = PyNumber_Rshift(stepval, shift); - if (unlikely(!tmp)) goto done; - Py_DECREF(stepval); stepval = tmp; - val |= ((unsigned int) idigit) << bits; - #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 - if (Py_SIZE(stepval) == 0) - goto unpacking_done; - #endif - } - idigit = PyLong_AsLong(stepval); - if (unlikely(idigit < 0)) goto done; - remaining_bits = ((int) sizeof(unsigned int) * 8) - bits - (is_unsigned ? 0 : 1); - if (unlikely(idigit >= (1L << remaining_bits))) - goto raise_overflow; - val |= ((unsigned int) idigit) << bits; - #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 - unpacking_done: - #endif - if (!is_unsigned) { - if (unlikely(val & (((unsigned int) 1) << (sizeof(unsigned int) * 8 - 1)))) - goto raise_overflow; - if (is_negative) - val = ~val; - } - ret = 0; - done: - Py_XDECREF(shift); - Py_XDECREF(mask); - Py_XDECREF(stepval); -#endif + int result = PyObject_RichCompareBool(v, Py_False, Py_LT); + if (unlikely(result < 0)) { Py_DECREF(v); - if (likely(!ret)) - return val; + return (unsigned int) -1; } - return (unsigned int) -1; + is_negative = result == 1; } - } else { - unsigned int val; - PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); - if (!tmp) return (unsigned int) -1; - val = __Pyx_PyInt_As_unsigned_int(tmp); - Py_DECREF(tmp); + if (is_unsigned && unlikely(is_negative)) { + Py_DECREF(v); + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + Py_DECREF(v); + if (unlikely(!stepval)) + return (unsigned int) -1; + } else { + stepval = v; + } + v = NULL; + val = (unsigned int) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(unsigned int) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + long idigit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + val |= ((unsigned int) idigit) << bits; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + } + Py_DECREF(shift); shift = NULL; + Py_DECREF(mask); mask = NULL; + { + long idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(unsigned int) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((unsigned int) idigit) << bits; + } + if (!is_unsigned) { + if (unlikely(val & (((unsigned int) 1) << (sizeof(unsigned int) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + if (unlikely(ret)) + return (unsigned int) -1; return val; } raise_overflow: @@ -7423,12 +7453,19 @@ static CYTHON_INLINE PyObject* __Pyx_PyInt_From_unsigned_int(unsigned int value) } } { - int one = 1; int little = (int)*(unsigned char *)&one; unsigned char *bytes = (unsigned char *)&value; -#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX >= 0x030d00A4 + if (is_unsigned) { + return PyLong_FromUnsignedNativeBytes(bytes, sizeof(value), -1); + } else { + return PyLong_FromNativeBytes(bytes, sizeof(value), -1); + } +#elif !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + int one = 1; int little = (int)*(unsigned char *)&one; return _PyLong_FromByteArray(bytes, sizeof(unsigned int), little, !is_unsigned); #else + int one = 1; int little = (int)*(unsigned char *)&one; PyObject *from_bytes, *result = NULL; PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); @@ -7503,12 +7540,19 @@ static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value) { } } { - int one = 1; int little = (int)*(unsigned char *)&one; unsigned char *bytes = (unsigned char *)&value; -#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX >= 0x030d00A4 + if (is_unsigned) { + return PyLong_FromUnsignedNativeBytes(bytes, sizeof(value), -1); + } else { + return PyLong_FromNativeBytes(bytes, sizeof(value), -1); + } +#elif !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + int one = 1; int little = (int)*(unsigned char *)&one; return _PyLong_FromByteArray(bytes, sizeof(long), little, !is_unsigned); #else + int one = 1; int little = (int)*(unsigned char *)&one; PyObject *from_bytes, *result = NULL; PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); @@ -7558,245 +7602,239 @@ static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *x) { } return (long) val; } - } else + } #endif - if (likely(PyLong_Check(x))) { - if (is_unsigned) { + if (unlikely(!PyLong_Check(x))) { + long val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (long) -1; + val = __Pyx_PyInt_As_long(tmp); + Py_DECREF(tmp); + return val; + } + if (is_unsigned) { #if CYTHON_USE_PYLONG_INTERNALS - if (unlikely(__Pyx_PyLong_IsNeg(x))) { - goto raise_neg_overflow; - } else if (__Pyx_PyLong_IsCompact(x)) { - __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) - } else { - const digit* digits = __Pyx_PyLong_Digits(x); - assert(__Pyx_PyLong_DigitCount(x) > 1); - switch (__Pyx_PyLong_DigitCount(x)) { - case 2: - if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(long) >= 2 * PyLong_SHIFT)) { - return (long) (((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); - } + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 2 * PyLong_SHIFT)) { + return (long) (((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); } - break; - case 3: - if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(long) >= 3 * PyLong_SHIFT)) { - return (long) (((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); - } + } + break; + case 3: + if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 3 * PyLong_SHIFT)) { + return (long) (((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); } - break; - case 4: - if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(long) >= 4 * PyLong_SHIFT)) { - return (long) (((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); - } + } + break; + case 4: + if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 4 * PyLong_SHIFT)) { + return (long) (((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); } - break; - } + } + break; } + } #endif #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 - if (unlikely(Py_SIZE(x) < 0)) { - goto raise_neg_overflow; - } + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } #else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (long) -1; - if (unlikely(result == 1)) - goto raise_neg_overflow; - } + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (long) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } #endif - if ((sizeof(long) <= sizeof(unsigned long))) { - __PYX_VERIFY_RETURN_INT_EXC(long, unsigned long, PyLong_AsUnsignedLong(x)) + if ((sizeof(long) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(long, unsigned long, PyLong_AsUnsignedLong(x)) #ifdef HAVE_LONG_LONG - } else if ((sizeof(long) <= sizeof(unsigned PY_LONG_LONG))) { - __PYX_VERIFY_RETURN_INT_EXC(long, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) + } else if ((sizeof(long) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(long, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) #endif - } - } else { + } + } else { #if CYTHON_USE_PYLONG_INTERNALS - if (__Pyx_PyLong_IsCompact(x)) { - __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) - } else { - const digit* digits = __Pyx_PyLong_Digits(x); - assert(__Pyx_PyLong_DigitCount(x) > 1); - switch (__Pyx_PyLong_SignedDigitCount(x)) { - case -2: - if ((8 * sizeof(long) - 1 > 1 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { - return (long) (((long)-1)*(((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(long) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); } - break; - case 2: - if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { - return (long) ((((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } + } + break; + case 2: + if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + return (long) ((((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); } - break; - case -3: - if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { - return (long) (((long)-1)*(((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } + } + break; + case -3: + if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); } - break; - case 3: - if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { - return (long) ((((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } + } + break; + case 3: + if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + return (long) ((((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); } - break; - case -4: - if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) { - return (long) (((long)-1)*(((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } + } + break; + case -4: + if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); } - break; - case 4: - if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) { - return (long) ((((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); - } + } + break; + case 4: + if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) { + return (long) ((((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); } - break; - } + } + break; } + } #endif - if ((sizeof(long) <= sizeof(long))) { - __PYX_VERIFY_RETURN_INT_EXC(long, long, PyLong_AsLong(x)) + if ((sizeof(long) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(long, long, PyLong_AsLong(x)) #ifdef HAVE_LONG_LONG - } else if ((sizeof(long) <= sizeof(PY_LONG_LONG))) { - __PYX_VERIFY_RETURN_INT_EXC(long, PY_LONG_LONG, PyLong_AsLongLong(x)) + } else if ((sizeof(long) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(long, PY_LONG_LONG, PyLong_AsLongLong(x)) #endif - } + } + } + { + long val; + int ret = -1; +#if PY_VERSION_HEX >= 0x030d00A6 && !CYTHON_COMPILING_IN_LIMITED_API + Py_ssize_t bytes_copied = PyLong_AsNativeBytes( + x, &val, sizeof(val), Py_ASNATIVEBYTES_NATIVE_ENDIAN | (is_unsigned ? Py_ASNATIVEBYTES_UNSIGNED_BUFFER | Py_ASNATIVEBYTES_REJECT_NEGATIVE : 0)); + if (unlikely(bytes_copied == -1)) { + } else if (unlikely(bytes_copied > (Py_ssize_t) sizeof(val))) { + goto raise_overflow; + } else { + ret = 0; + } +#elif PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)x, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *v; + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (likely(PyLong_CheckExact(x))) { + v = __Pyx_NewRef(x); + } else { + v = PyNumber_Long(x); + if (unlikely(!v)) return (long) -1; + assert(PyLong_CheckExact(v)); } { - long val; - PyObject *v = __Pyx_PyNumber_IntOrLong(x); -#if PY_MAJOR_VERSION < 3 - if (likely(v) && !PyLong_Check(v)) { - PyObject *tmp = v; - v = PyNumber_Long(tmp); - Py_DECREF(tmp); - } -#endif - if (likely(v)) { - int ret = -1; -#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) - int one = 1; int is_little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&val; - ret = _PyLong_AsByteArray((PyLongObject *)v, - bytes, sizeof(val), - is_little, !is_unsigned); -#else - PyObject *stepval = NULL, *mask = NULL, *shift = NULL; - int bits, remaining_bits, is_negative = 0; - long idigit; - int chunk_size = (sizeof(long) < 8) ? 30 : 62; - if (unlikely(!PyLong_CheckExact(v))) { - PyObject *tmp = v; - v = PyNumber_Long(v); - assert(PyLong_CheckExact(v)); - Py_DECREF(tmp); - if (unlikely(!v)) return (long) -1; - } -#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 - if (Py_SIZE(x) == 0) - return (long) 0; - is_negative = Py_SIZE(x) < 0; -#else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (long) -1; - is_negative = result == 1; - } -#endif - if (is_unsigned && unlikely(is_negative)) { - goto raise_neg_overflow; - } else if (is_negative) { - stepval = PyNumber_Invert(v); - if (unlikely(!stepval)) - return (long) -1; - } else { - stepval = __Pyx_NewRef(v); - } - val = (long) 0; - mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; - shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; - for (bits = 0; bits < (int) sizeof(long) * 8 - chunk_size; bits += chunk_size) { - PyObject *tmp, *digit; - digit = PyNumber_And(stepval, mask); - if (unlikely(!digit)) goto done; - idigit = PyLong_AsLong(digit); - Py_DECREF(digit); - if (unlikely(idigit < 0)) goto done; - tmp = PyNumber_Rshift(stepval, shift); - if (unlikely(!tmp)) goto done; - Py_DECREF(stepval); stepval = tmp; - val |= ((long) idigit) << bits; - #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 - if (Py_SIZE(stepval) == 0) - goto unpacking_done; - #endif - } - idigit = PyLong_AsLong(stepval); - if (unlikely(idigit < 0)) goto done; - remaining_bits = ((int) sizeof(long) * 8) - bits - (is_unsigned ? 0 : 1); - if (unlikely(idigit >= (1L << remaining_bits))) - goto raise_overflow; - val |= ((long) idigit) << bits; - #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 - unpacking_done: - #endif - if (!is_unsigned) { - if (unlikely(val & (((long) 1) << (sizeof(long) * 8 - 1)))) - goto raise_overflow; - if (is_negative) - val = ~val; - } - ret = 0; - done: - Py_XDECREF(shift); - Py_XDECREF(mask); - Py_XDECREF(stepval); -#endif + int result = PyObject_RichCompareBool(v, Py_False, Py_LT); + if (unlikely(result < 0)) { Py_DECREF(v); - if (likely(!ret)) - return val; + return (long) -1; } - return (long) -1; + is_negative = result == 1; } - } else { - long val; - PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); - if (!tmp) return (long) -1; - val = __Pyx_PyInt_As_long(tmp); - Py_DECREF(tmp); + if (is_unsigned && unlikely(is_negative)) { + Py_DECREF(v); + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + Py_DECREF(v); + if (unlikely(!stepval)) + return (long) -1; + } else { + stepval = v; + } + v = NULL; + val = (long) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(long) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + long idigit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + val |= ((long) idigit) << bits; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + } + Py_DECREF(shift); shift = NULL; + Py_DECREF(mask); mask = NULL; + { + long idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(long) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((long) idigit) << bits; + } + if (!is_unsigned) { + if (unlikely(val & (((long) 1) << (sizeof(long) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + if (unlikely(ret)) + return (long) -1; return val; } raise_overflow: @@ -7831,245 +7869,239 @@ static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *x) { } return (int) val; } - } else + } #endif - if (likely(PyLong_Check(x))) { - if (is_unsigned) { + if (unlikely(!PyLong_Check(x))) { + int val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (int) -1; + val = __Pyx_PyInt_As_int(tmp); + Py_DECREF(tmp); + return val; + } + if (is_unsigned) { #if CYTHON_USE_PYLONG_INTERNALS - if (unlikely(__Pyx_PyLong_IsNeg(x))) { - goto raise_neg_overflow; - } else if (__Pyx_PyLong_IsCompact(x)) { - __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) - } else { - const digit* digits = __Pyx_PyLong_Digits(x); - assert(__Pyx_PyLong_DigitCount(x) > 1); - switch (__Pyx_PyLong_DigitCount(x)) { - case 2: - if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(int) >= 2 * PyLong_SHIFT)) { - return (int) (((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); - } + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 2 * PyLong_SHIFT)) { + return (int) (((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); } - break; - case 3: - if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(int) >= 3 * PyLong_SHIFT)) { - return (int) (((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); - } + } + break; + case 3: + if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 3 * PyLong_SHIFT)) { + return (int) (((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); } - break; - case 4: - if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(int) >= 4 * PyLong_SHIFT)) { - return (int) (((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); - } + } + break; + case 4: + if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 4 * PyLong_SHIFT)) { + return (int) (((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); } - break; - } + } + break; } + } #endif #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 - if (unlikely(Py_SIZE(x) < 0)) { - goto raise_neg_overflow; - } + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } #else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (int) -1; - if (unlikely(result == 1)) - goto raise_neg_overflow; - } + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } #endif - if ((sizeof(int) <= sizeof(unsigned long))) { - __PYX_VERIFY_RETURN_INT_EXC(int, unsigned long, PyLong_AsUnsignedLong(x)) + if ((sizeof(int) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(int, unsigned long, PyLong_AsUnsignedLong(x)) #ifdef HAVE_LONG_LONG - } else if ((sizeof(int) <= sizeof(unsigned PY_LONG_LONG))) { - __PYX_VERIFY_RETURN_INT_EXC(int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) + } else if ((sizeof(int) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) #endif - } - } else { + } + } else { #if CYTHON_USE_PYLONG_INTERNALS - if (__Pyx_PyLong_IsCompact(x)) { - __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) - } else { - const digit* digits = __Pyx_PyLong_Digits(x); - assert(__Pyx_PyLong_DigitCount(x) > 1); - switch (__Pyx_PyLong_SignedDigitCount(x)) { - case -2: - if ((8 * sizeof(int) - 1 > 1 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { - return (int) (((int)-1)*(((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(int) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); } - break; - case 2: - if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { - return (int) ((((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } + } + break; + case 2: + if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + return (int) ((((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); } - break; - case -3: - if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { - return (int) (((int)-1)*(((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } + } + break; + case -3: + if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); } - break; - case 3: - if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { - return (int) ((((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } + } + break; + case 3: + if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + return (int) ((((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); } - break; - case -4: - if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) { - return (int) (((int)-1)*(((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } + } + break; + case -4: + if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); } - break; - case 4: - if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) { - if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { - __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) - } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) { - return (int) ((((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); - } + } + break; + case 4: + if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) { + return (int) ((((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); } - break; - } + } + break; } + } #endif - if ((sizeof(int) <= sizeof(long))) { - __PYX_VERIFY_RETURN_INT_EXC(int, long, PyLong_AsLong(x)) + if ((sizeof(int) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(int, long, PyLong_AsLong(x)) #ifdef HAVE_LONG_LONG - } else if ((sizeof(int) <= sizeof(PY_LONG_LONG))) { - __PYX_VERIFY_RETURN_INT_EXC(int, PY_LONG_LONG, PyLong_AsLongLong(x)) + } else if ((sizeof(int) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int, PY_LONG_LONG, PyLong_AsLongLong(x)) #endif - } + } + } + { + int val; + int ret = -1; +#if PY_VERSION_HEX >= 0x030d00A6 && !CYTHON_COMPILING_IN_LIMITED_API + Py_ssize_t bytes_copied = PyLong_AsNativeBytes( + x, &val, sizeof(val), Py_ASNATIVEBYTES_NATIVE_ENDIAN | (is_unsigned ? Py_ASNATIVEBYTES_UNSIGNED_BUFFER | Py_ASNATIVEBYTES_REJECT_NEGATIVE : 0)); + if (unlikely(bytes_copied == -1)) { + } else if (unlikely(bytes_copied > (Py_ssize_t) sizeof(val))) { + goto raise_overflow; + } else { + ret = 0; + } +#elif PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)x, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *v; + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (likely(PyLong_CheckExact(x))) { + v = __Pyx_NewRef(x); + } else { + v = PyNumber_Long(x); + if (unlikely(!v)) return (int) -1; + assert(PyLong_CheckExact(v)); } { - int val; - PyObject *v = __Pyx_PyNumber_IntOrLong(x); -#if PY_MAJOR_VERSION < 3 - if (likely(v) && !PyLong_Check(v)) { - PyObject *tmp = v; - v = PyNumber_Long(tmp); - Py_DECREF(tmp); - } -#endif - if (likely(v)) { - int ret = -1; -#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) - int one = 1; int is_little = (int)*(unsigned char *)&one; - unsigned char *bytes = (unsigned char *)&val; - ret = _PyLong_AsByteArray((PyLongObject *)v, - bytes, sizeof(val), - is_little, !is_unsigned); -#else - PyObject *stepval = NULL, *mask = NULL, *shift = NULL; - int bits, remaining_bits, is_negative = 0; - long idigit; - int chunk_size = (sizeof(long) < 8) ? 30 : 62; - if (unlikely(!PyLong_CheckExact(v))) { - PyObject *tmp = v; - v = PyNumber_Long(v); - assert(PyLong_CheckExact(v)); - Py_DECREF(tmp); - if (unlikely(!v)) return (int) -1; - } -#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 - if (Py_SIZE(x) == 0) - return (int) 0; - is_negative = Py_SIZE(x) < 0; -#else - { - int result = PyObject_RichCompareBool(x, Py_False, Py_LT); - if (unlikely(result < 0)) - return (int) -1; - is_negative = result == 1; - } -#endif - if (is_unsigned && unlikely(is_negative)) { - goto raise_neg_overflow; - } else if (is_negative) { - stepval = PyNumber_Invert(v); - if (unlikely(!stepval)) - return (int) -1; - } else { - stepval = __Pyx_NewRef(v); - } - val = (int) 0; - mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; - shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; - for (bits = 0; bits < (int) sizeof(int) * 8 - chunk_size; bits += chunk_size) { - PyObject *tmp, *digit; - digit = PyNumber_And(stepval, mask); - if (unlikely(!digit)) goto done; - idigit = PyLong_AsLong(digit); - Py_DECREF(digit); - if (unlikely(idigit < 0)) goto done; - tmp = PyNumber_Rshift(stepval, shift); - if (unlikely(!tmp)) goto done; - Py_DECREF(stepval); stepval = tmp; - val |= ((int) idigit) << bits; - #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 - if (Py_SIZE(stepval) == 0) - goto unpacking_done; - #endif - } - idigit = PyLong_AsLong(stepval); - if (unlikely(idigit < 0)) goto done; - remaining_bits = ((int) sizeof(int) * 8) - bits - (is_unsigned ? 0 : 1); - if (unlikely(idigit >= (1L << remaining_bits))) - goto raise_overflow; - val |= ((int) idigit) << bits; - #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 - unpacking_done: - #endif - if (!is_unsigned) { - if (unlikely(val & (((int) 1) << (sizeof(int) * 8 - 1)))) - goto raise_overflow; - if (is_negative) - val = ~val; - } - ret = 0; - done: - Py_XDECREF(shift); - Py_XDECREF(mask); - Py_XDECREF(stepval); -#endif + int result = PyObject_RichCompareBool(v, Py_False, Py_LT); + if (unlikely(result < 0)) { Py_DECREF(v); - if (likely(!ret)) - return val; + return (int) -1; } - return (int) -1; + is_negative = result == 1; } - } else { - int val; - PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); - if (!tmp) return (int) -1; - val = __Pyx_PyInt_As_int(tmp); - Py_DECREF(tmp); + if (is_unsigned && unlikely(is_negative)) { + Py_DECREF(v); + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + Py_DECREF(v); + if (unlikely(!stepval)) + return (int) -1; + } else { + stepval = v; + } + v = NULL; + val = (int) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(int) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + long idigit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + val |= ((int) idigit) << bits; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + } + Py_DECREF(shift); shift = NULL; + Py_DECREF(mask); mask = NULL; + { + long idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(int) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((int) idigit) << bits; + } + if (!is_unsigned) { + if (unlikely(val & (((int) 1) << (sizeof(int) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + if (unlikely(ret)) + return (int) -1; return val; } raise_overflow: diff --git a/supriya/osc.py b/supriya/osc.py deleted file mode 100644 index 3fb03466e..000000000 --- a/supriya/osc.py +++ /dev/null @@ -1,1133 +0,0 @@ -""" -Tools for sending, receiving and handling OSC messages. -""" - -import abc -import asyncio -import collections -import contextlib -import dataclasses -import datetime -import enum -import inspect -import logging -import queue -import socket -import socketserver -import struct -import threading -import time -from collections.abc import Sequence as SequenceABC -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Sequence, - Set, - Tuple, - Union, -) - -from uqbar.objects import get_repr - -from .utils import group_by_count - -osc_protocol_logger = logging.getLogger(__name__) -osc_in_logger = logging.getLogger("supriya.osc.in") -osc_out_logger = logging.getLogger("supriya.osc.out") -udp_in_logger = logging.getLogger("supriya.udp.in") -udp_out_logger = logging.getLogger("supriya.udp.out") - -BUNDLE_PREFIX = b"#bundle\x00" -IMMEDIATELY = struct.pack(">Q", 1) -NTP_TIMESTAMP_TO_SECONDS = 1.0 / 2.0**32.0 -SECONDS_TO_NTP_TIMESTAMP = 2.0**32.0 -SYSTEM_EPOCH = datetime.date(*time.gmtime(0)[0:3]) -NTP_EPOCH = datetime.date(1900, 1, 1) -NTP_DELTA = (SYSTEM_EPOCH - NTP_EPOCH).days * 24 * 3600 - - -class OscMessage: - """ - An OSC message. - - .. container:: example - - :: - - >>> from supriya.osc import OscMessage - >>> osc_message = OscMessage("/g_new", 0, 0) - >>> osc_message - OscMessage('/g_new', 0, 0) - - :: - - >>> datagram = osc_message.to_datagram() - >>> OscMessage.from_datagram(datagram) - OscMessage('/g_new', 0, 0) - - :: - - >>> print(osc_message) - size 20 - 0 2f 67 5f 6e 65 77 00 00 2c 69 69 00 00 00 00 00 |/g_new..,ii.....| - 16 00 00 00 00 |....| - - .. container:: example - - :: - - >>> osc_message = OscMessage("/foo", True, [None, [3.25]], OscMessage("/bar")) - >>> osc_message - OscMessage('/foo', True, [None, [3.25]], OscMessage('/bar')) - - :: - - >>> datagram = osc_message.to_datagram() - >>> OscMessage.from_datagram(datagram) - OscMessage('/foo', True, [None, [3.25]], OscMessage('/bar')) - - :: - - >>> print(osc_message) - size 40 - 0 2f 66 6f 6f 00 00 00 00 2c 54 5b 4e 5b 66 5d 5d |/foo....,T[N[f]]| - 16 62 00 00 00 40 50 00 00 00 00 00 0c 2f 62 61 72 |b...@P....../bar| - 32 00 00 00 00 2c 00 00 00 |....,...| - - .. container:: example - - :: - - >>> osc_message = supriya.osc.OscMessage( - ... "/foo", - ... 1, - ... 2.5, - ... supriya.osc.OscBundle( - ... contents=( - ... supriya.osc.OscMessage("/bar", "baz", 3.0), - ... supriya.osc.OscMessage("/ffff", False, True, None), - ... ) - ... ), - ... ["a", "b", ["c", "d"]], - ... ) - >>> osc_message - OscMessage('/foo', 1, 2.5, OscBundle( - contents=( - OscMessage('/bar', 'baz', 3.0), - OscMessage('/ffff', False, True, None), - ), - ), ['a', 'b', ['c', 'd']]) - - :: - - >>> datagram = osc_message.to_datagram() - >>> OscMessage.from_datagram(datagram) - OscMessage('/foo', 1, 2.5, OscBundle( - contents=( - OscMessage('/bar', 'baz', 3.0), - OscMessage('/ffff', False, True, None), - ), - ), ['a', 'b', ['c', 'd']]) - - :: - - >>> print(osc_message) - size 112 - 0 2f 66 6f 6f 00 00 00 00 2c 69 66 62 5b 73 73 5b |/foo....,ifb[ss[| - 16 73 73 5d 5d 00 00 00 00 00 00 00 01 40 20 00 00 |ss]]........@ ..| - 32 00 00 00 3c 23 62 75 6e 64 6c 65 00 00 00 00 00 |...<#bundle.....| - 48 00 00 00 01 00 00 00 14 2f 62 61 72 00 00 00 00 |......../bar....| - 64 2c 73 66 00 62 61 7a 00 40 40 00 00 00 00 00 10 |,sf.baz.@@......| - 80 2f 66 66 66 66 00 00 00 2c 46 54 4e 00 00 00 00 |/ffff...,FTN....| - 96 61 00 00 00 62 00 00 00 63 00 00 00 64 00 00 00 |a...b...c...d...| - """ - - ### INITIALIZER ### - - def __init__(self, address, *contents) -> None: - if isinstance(address, enum.Enum): - address = address.value - if not isinstance(address, (str, int)): - raise ValueError(f"address must be int or str, got {address}") - self.address = address - self.contents = tuple(contents) - - ### SPECIAL METHODS ### - - def __eq__(self, other) -> bool: - if type(self) is not type(other): - return False - if self.address != other.address: - return False - if self.contents != other.contents: - return False - return True - - def __repr__(self) -> str: - return "{}({})".format( - type(self).__name__, - ", ".join(repr(_) for _ in [self.address, *self.contents]), - ) - - def __str__(self) -> str: - return format_datagram(bytearray(self.to_datagram())) - - ### PRIVATE METHODS ### - - @staticmethod - def _decode_blob(data): - actual_length, remainder = struct.unpack(">I", data[:4])[0], data[4:] - padded_length = actual_length - if actual_length % 4 != 0: - padded_length = (actual_length // 4 + 1) * 4 - return remainder[:padded_length][:actual_length], remainder[padded_length:] - - @staticmethod - def _decode_string(data): - actual_length = data.index(b"\x00") - padded_length = (actual_length // 4 + 1) * 4 - return str(data[:actual_length], "ascii"), data[padded_length:] - - @staticmethod - def _encode_string(value): - result = bytes(value + "\x00", "ascii") - if len(result) % 4 != 0: - width = (len(result) // 4 + 1) * 4 - result = result.ljust(width, b"\x00") - return result - - @staticmethod - def _encode_blob(value): - result = bytes(struct.pack(">I", len(value)) + value) - if len(result) % 4 != 0: - width = (len(result) // 4 + 1) * 4 - result = result.ljust(width, b"\x00") - return result - - @classmethod - def _encode_value(cls, value): - if hasattr(value, "to_datagram"): - value = bytearray(value.to_datagram()) - elif isinstance(value, enum.Enum): - value = value.value - type_tags, encoded_value = "", b"" - if isinstance(value, (bytearray, bytes)): - type_tags += "b" - encoded_value = cls._encode_blob(value) - elif isinstance(value, str): - type_tags += "s" - encoded_value = cls._encode_string(value) - elif isinstance(value, bool): - type_tags += "T" if value else "F" - elif isinstance(value, float): - type_tags += "f" - encoded_value += struct.pack(">f", value) - elif isinstance(value, int): - type_tags += "i" - encoded_value += struct.pack(">i", value) - elif value is None: - type_tags += "N" - elif isinstance(value, SequenceABC): - type_tags += "[" - for sub_value in value: - sub_type_tags, sub_encoded_value = cls._encode_value(sub_value) - type_tags += sub_type_tags - encoded_value += sub_encoded_value - type_tags += "]" - else: - message = "Cannot encode {!r}".format(value) - raise TypeError(message) - return type_tags, encoded_value - - ### PUBLIC METHODS ### - - def to_datagram(self) -> bytes: - # address can be a string or (in SuperCollider) an int - if isinstance(self.address, str): - encoded_address = self._encode_string(self.address) - else: - encoded_address = struct.pack(">i", self.address) - encoded_type_tags = "," - encoded_contents = b"" - for value in self.contents or (): - type_tags, encoded_value = self._encode_value(value) - encoded_type_tags += type_tags - encoded_contents += encoded_value - return ( - encoded_address + self._encode_string(encoded_type_tags) + encoded_contents - ) - - @classmethod - def from_datagram(cls, datagram): - remainder = datagram - address, remainder = cls._decode_string(remainder) - type_tags, remainder = cls._decode_string(remainder) - contents = [] - array_stack = [contents] - for type_tag in type_tags[1:]: - if type_tag == "i": - value, remainder = struct.unpack(">i", remainder[:4])[0], remainder[4:] - array_stack[-1].append(value) - elif type_tag == "f": - value, remainder = struct.unpack(">f", remainder[:4])[0], remainder[4:] - array_stack[-1].append(value) - elif type_tag == "d": - value, remainder = struct.unpack(">d", remainder[:8])[0], remainder[8:] - array_stack[-1].append(value) - elif type_tag == "s": - value, remainder = cls._decode_string(remainder) - array_stack[-1].append(value) - elif type_tag == "b": - value, remainder = cls._decode_blob(remainder) - for class_ in (OscBundle, OscMessage): - try: - value = class_.from_datagram(value) - break - except Exception: - pass - array_stack[-1].append(value) - elif type_tag == "T": - array_stack[-1].append(True) - elif type_tag == "F": - array_stack[-1].append(False) - elif type_tag == "N": - array_stack[-1].append(None) - elif type_tag == "[": - array = [] - array_stack[-1].append(array) - array_stack.append(array) - elif type_tag == "]": - array_stack.pop() - else: - raise RuntimeError(f"Unable to parse type {type_tag!r}") - return cls(address, *contents) - - def to_list(self): - result = [self.address] - for x in self.contents: - if hasattr(x, "to_list"): - result.append(x.to_list()) - else: - result.append(x) - return result - - -class OscBundle: - """ - An OSC bundle. - - :: - - >>> import supriya.osc - >>> message_one = supriya.osc.OscMessage("/one", 1) - >>> message_two = supriya.osc.OscMessage("/two", 2) - >>> message_three = supriya.osc.OscMessage("/three", 3) - - :: - - >>> inner_bundle = supriya.osc.OscBundle( - ... timestamp=1401557034.5, - ... contents=(message_one, message_two), - ... ) - >>> inner_bundle - OscBundle( - contents=( - OscMessage('/one', 1), - OscMessage('/two', 2), - ), - timestamp=1401557034.5, - ) - - :: - - >>> print(inner_bundle) - size 56 - 0 23 62 75 6e 64 6c 65 00 d7 34 8e aa 80 00 00 00 |#bundle..4......| - 16 00 00 00 10 2f 6f 6e 65 00 00 00 00 2c 69 00 00 |..../one....,i..| - 32 00 00 00 01 00 00 00 10 2f 74 77 6f 00 00 00 00 |......../two....| - 48 2c 69 00 00 00 00 00 02 |,i......| - - :: - - >>> outer_bundle = supriya.osc.OscBundle( - ... contents=(inner_bundle, message_three), - ... ) - >>> outer_bundle - OscBundle( - contents=( - OscBundle( - contents=( - OscMessage('/one', 1), - OscMessage('/two', 2), - ), - timestamp=1401557034.5, - ), - OscMessage('/three', 3), - ), - ) - - :: - - >>> print(outer_bundle) - size 96 - 0 23 62 75 6e 64 6c 65 00 00 00 00 00 00 00 00 01 |#bundle.........| - 16 00 00 00 38 23 62 75 6e 64 6c 65 00 d7 34 8e aa |...8#bundle..4..| - 32 80 00 00 00 00 00 00 10 2f 6f 6e 65 00 00 00 00 |......../one....| - 48 2c 69 00 00 00 00 00 01 00 00 00 10 2f 74 77 6f |,i........../two| - 64 00 00 00 00 2c 69 00 00 00 00 00 02 00 00 00 10 |....,i..........| - 80 2f 74 68 72 65 65 00 00 2c 69 00 00 00 00 00 03 |/three..,i......| - - :: - - >>> datagram = outer_bundle.to_datagram() - - :: - - >>> decoded_bundle = supriya.osc.OscBundle.from_datagram(datagram) - >>> decoded_bundle - OscBundle( - contents=( - OscBundle( - contents=( - OscMessage('/one', 1), - OscMessage('/two', 2), - ), - timestamp=1401557034.5, - ), - OscMessage('/three', 3), - ), - ) - - :: - - >>> decoded_bundle == outer_bundle - True - """ - - ### INITIALIZER ### - - def __init__(self, timestamp=None, contents=None) -> None: - prototype = (OscMessage, type(self)) - self.timestamp = timestamp - contents = contents or () - for x in contents or (): - if not isinstance(x, prototype): - raise ValueError(contents) - self.contents = tuple(contents) - - ### SPECIAL METHODS ### - - def __eq__(self, other) -> bool: - if type(self) is not type(other): - return False - if self.timestamp != other.timestamp: - return False - if self.contents != other.contents: - return False - return True - - def __repr__(self) -> str: - return get_repr(self) - - def __str__(self) -> str: - return format_datagram(bytearray(self.to_datagram())) - - ### PRIVATE METHODS ### - - @staticmethod - def _decode_date(data): - data, remainder = data[:8], data[8:] - if data == IMMEDIATELY: - return None, remainder - date = (struct.unpack(">Q", data)[0] / SECONDS_TO_NTP_TIMESTAMP) - NTP_DELTA - return date, remainder - - @staticmethod - def _encode_date(seconds, realtime=True): - if seconds is None: - return IMMEDIATELY - if realtime: - seconds = seconds + NTP_DELTA - if seconds >= 4294967296: # 2**32 - seconds = seconds % 4294967296 - return struct.pack(">Q", int(seconds * SECONDS_TO_NTP_TIMESTAMP)) - - ### PUBLIC METHODS ### - - @classmethod - def from_datagram(cls, datagram): - if not datagram.startswith(BUNDLE_PREFIX): - raise ValueError("datagram is not a bundle") - remainder = datagram[8:] - timestamp, remainder = cls._decode_date(remainder) - contents = [] - while len(remainder): - length, remainder = struct.unpack(">i", remainder[:4])[0], remainder[4:] - if remainder.startswith(BUNDLE_PREFIX): - item = cls.from_datagram(remainder[:length]) - else: - item = OscMessage.from_datagram(remainder[:length]) - contents.append(item) - remainder = remainder[length:] - osc_bundle = cls(timestamp=timestamp, contents=tuple(contents)) - return osc_bundle - - @classmethod - def partition(cls, messages, timestamp=None): - bundles = [] - contents = [] - message = collections.deque(messages) - remaining = maximum = 8192 - len(BUNDLE_PREFIX) - 4 - while messages: - message = messages.popleft() - datagram = message.to_datagram() - remaining -= len(datagram) + 4 - if remaining > 0: - contents.append(message) - else: - bundles.append(cls(timestamp=timestamp, contents=contents)) - contents = [message] - remaining = maximum - if contents: - bundles.append(cls(timestamp=timestamp, contents=contents)) - return bundles - - def to_datagram(self, realtime=True) -> bytes: - datagram = BUNDLE_PREFIX - datagram += self._encode_date(self.timestamp, realtime=realtime) - for content in self.contents: - content_datagram = content.to_datagram() - datagram += struct.pack(">i", len(content_datagram)) - datagram += content_datagram - return datagram - - def to_list(self): - result = [self.timestamp] - result.append([x.to_list() for x in self.contents]) - return result - - -class OscProtocolOffline(Exception): - pass - - -class OscProtocolAlreadyConnected(Exception): - pass - - -class OscCallback(NamedTuple): - pattern: Tuple[Union[str, int, float], ...] - procedure: Callable - failure_pattern: Optional[Tuple[Union[str, int, float], ...]] = None - once: bool = False - args: Optional[Tuple] = None - kwargs: Optional[Dict] = None - - -@dataclasses.dataclass -class HealthCheck: - request_pattern: List[str] - response_pattern: List[str] - callback: Callable - active: bool = True - timeout: float = 1.0 - backoff_factor: float = 1.5 - max_attempts: int = 5 - - -class OscProtocol(metaclass=abc.ABCMeta): - ### INITIALIZER ### - - def __init__(self) -> None: - self.callbacks: Dict[Any, Any] = {} - self.captures: Set[Capture] = set() - self.healthcheck: Optional[HealthCheck] = None - self.healthcheck_osc_callback: Optional[OscCallback] = None - self.attempts = 0 - self.ip_address = "127.0.0.1" - self.is_running: bool = False - self.port = 57551 - - ### PRIVATE METHODS ### - - def _add_callback(self, callback: OscCallback) -> None: - patterns = [callback.pattern] - if callback.failure_pattern: - patterns.append(callback.failure_pattern) - for pattern in patterns: - callback_map = self.callbacks - for item in pattern: - callbacks, callback_map = callback_map.setdefault(item, ([], {})) - callbacks.append(callback) - - def _disconnect(self) -> None: - raise NotImplementedError - - def _match_callbacks(self, message) -> List[OscCallback]: - items = (message.address,) + message.contents - matching_callbacks = [] - callback_map = self.callbacks - for item in items: - if item not in callback_map: - break - callbacks, callback_map = callback_map[item] - matching_callbacks.extend(callbacks) - for callback in matching_callbacks: - if callback.once: - self.unregister(callback) - return matching_callbacks - - def _remove_callback(self, callback: OscCallback) -> None: - def delete(pattern, original_callback_map): - key = pattern.pop(0) - if key not in original_callback_map: - return - callbacks, callback_map = original_callback_map[key] - if pattern: - delete(pattern, callback_map) - if callback in callbacks: - callbacks.remove(callback) - if not callbacks and not callback_map: - original_callback_map.pop(key) - - patterns = [callback.pattern] - if callback.failure_pattern: - patterns.append(callback.failure_pattern) - for pattern in patterns: - delete(list(pattern), self.callbacks) - - def _pass_healthcheck(self, message) -> None: - osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] healthcheck: passed") - self.attempts = 0 - - def _setup( - self, ip_address: str, port: int, healthcheck: Optional[HealthCheck] - ) -> None: - self.ip_address = ip_address - self.port = port - self.healthcheck = healthcheck - if self.healthcheck: - self.healthcheck_osc_callback = self.register( - pattern=self.healthcheck.response_pattern, - procedure=self._pass_healthcheck, - ) - - def _teardown(self): - osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] Tearing down...") - self.is_running = False - if self.healthcheck is not None: - self.unregister(self.healthcheck_osc_callback) - - def _validate_callback( - self, - pattern, - procedure, - *, - failure_pattern=None, - once=False, - args: Optional[Tuple] = None, - kwargs: Optional[Dict] = None, - ): - if isinstance(pattern, (str, int, float)): - pattern = [pattern] - if isinstance(failure_pattern, (str, int, float)): - failure_pattern = [failure_pattern] - if not callable(procedure): - raise ValueError(procedure) - return OscCallback( - pattern=tuple(pattern), - failure_pattern=failure_pattern, - procedure=procedure, - once=bool(once), - args=args, - kwargs=kwargs, - ) - - def _validate_receive(self, datagram): - udp_in_logger.debug(f"[{self.ip_address}:{self.port}] {datagram}") - try: - message = OscMessage.from_datagram(datagram) - except Exception: - raise - osc_in_logger.debug(f"[{self.ip_address}:{self.port}] {message!r}") - for capture in self.captures: - capture.messages.append( - CaptureEntry(timestamp=time.time(), label="R", message=message) - ) - for callback in self._match_callbacks(message): - yield callback, message - - def _validate_send(self, message): - if not self.is_running: - raise OscProtocolOffline - if not isinstance(message, (str, SequenceABC, OscBundle, OscMessage)): - raise ValueError(message) - if isinstance(message, str): - message = OscMessage(message) - elif isinstance(message, SequenceABC): - message = OscMessage(*message) - osc_out_logger.debug(f"[{self.ip_address}:{self.port}] {message!r}") - for capture in self.captures: - capture.messages.append( - CaptureEntry(timestamp=time.time(), label="S", message=message) - ) - datagram = message.to_datagram() - udp_out_logger.debug(f"[{self.ip_address}:{self.port}] {datagram}") - return datagram - - ### PUBLIC METHODS ### - - @abc.abstractmethod - def activate_healthcheck(self) -> None: - raise NotImplementedError - - def capture(self) -> "Capture": - return Capture(self) - - def disconnect(self) -> None: - osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] disconnecting") - self._disconnect() - osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] ...disconnected") - - @abc.abstractmethod - def register( - self, - pattern: Sequence[Union[str, float]], - procedure: Callable[[OscMessage], None], - *, - failure_pattern: Optional[Sequence[Union[str, float]]] = None, - once: bool = False, - args: Optional[Tuple] = None, - kwargs: Optional[Dict] = None, - ) -> OscCallback: - raise NotImplementedError - - @abc.abstractmethod - def send(self, message) -> None: - raise NotImplementedError - - @abc.abstractmethod - def unregister(self, callback: OscCallback): - raise NotImplementedError - - -class AsyncOscProtocol(asyncio.DatagramProtocol, OscProtocol): - ### INITIALIZER ### - - def __init__(self) -> None: - asyncio.DatagramProtocol.__init__(self) - OscProtocol.__init__(self) - self.background_tasks: Set[asyncio.Task] = set() - self.healthcheck_task: Optional[asyncio.Task] = None - - ### PRIVATE METHODS ### - - def _disconnect(self) -> None: - if not self.is_running: - osc_protocol_logger.info( - f"{self.ip_address}:{self.port} already disconnected!" - ) - return - self._teardown() - self.transport.close() - if self.healthcheck_task: - self.healthcheck_task.cancel() - - async def _run_healthcheck(self): - while self.is_running: - if self.attempts >= self.healthcheck.max_attempts: - osc_protocol_logger.info( - f"[{self.ip_address}:{self.port}] health check: failure limit exceeded" - ) - self.exit_future.set_result(True) - self._teardown() - self.transport.close() - obj_ = self.healthcheck.callback() - if asyncio.iscoroutine(obj_): - asyncio.get_running_loop().create_task(obj_) - return - osc_protocol_logger.info( - f"[{self.ip_address}:{self.port}] healthcheck: checking..." - ) - self.send(OscMessage(*self.healthcheck.request_pattern)) - sleep_time = self.healthcheck.timeout * pow( - self.healthcheck.backoff_factor, self.attempts - ) - self.attempts += 1 - await asyncio.sleep(sleep_time) - - ### PUBLIC METHODS ### - - def activate_healthcheck(self) -> None: - if not self.healthcheck: - return - elif self.healthcheck.active: - return - self.healthcheck_task = asyncio.get_running_loop().create_task( - self._run_healthcheck() - ) - - async def connect( - self, ip_address: str, port: int, *, healthcheck: Optional[HealthCheck] = None - ): - osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] connecting...") - if self.is_running: - osc_protocol_logger.info( - f"[{self.ip_address}:{self.port}] already connected!" - ) - raise OscProtocolAlreadyConnected - self._setup(ip_address, port, healthcheck) - loop = asyncio.get_running_loop() - self.exit_future = loop.create_future() - _, protocol = await loop.create_datagram_endpoint( - lambda: self, remote_addr=(ip_address, port) - ) - if self.healthcheck and self.healthcheck.active: - self.healthcheck_task = asyncio.get_running_loop().create_task( - self._run_healthcheck() - ) - osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] ...connected") - - def connection_made(self, transport): - osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] connection made") - self.transport = transport - self.is_running = True - - def connection_lost(self, exc): - osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] connection lost") - self.exit_future.set_result(True) - - def datagram_received(self, data, addr): - loop = asyncio.get_running_loop() - for callback, message in self._validate_receive(data): - result = callback.procedure( - message, *(callback.args or ()), **(callback.kwargs or {}) - ) - if inspect.iscoroutine(result): - task = loop.create_task(result) - self.background_tasks.add(task) - task.add_done_callback(self.background_tasks.discard) - - def error_received(self, exc): - osc_out_logger.warning(f"[{self.ip_address}:{self.port}] errored: {exc}") - - def register( - self, - pattern: Sequence[Union[str, float]], - procedure: Callable[[OscMessage], None], - *, - failure_pattern: Optional[Sequence[Union[str, float]]] = None, - once: bool = False, - args: Optional[Tuple] = None, - kwargs: Optional[Dict] = None, - ) -> OscCallback: - osc_protocol_logger.info( - f"[{self.ip_address}:{self.port}] registering pattern: {pattern!r}" - ) - callback = self._validate_callback( - pattern, - procedure, - failure_pattern=failure_pattern, - once=once, - args=args, - kwargs=kwargs, - ) - self._add_callback(callback) - return callback - - def send(self, message): - osc_protocol_logger.debug( - f"[{self.ip_address}:{self.port}] sending: {message!r}" - ) - datagram = self._validate_send(message) - return self.transport.sendto(datagram) - - def unregister(self, callback: OscCallback): - self._remove_callback(callback) - - -class ThreadedOscServer(socketserver.UDPServer): - osc_protocol: "ThreadedOscProtocol" - - def verify_request(self, request, client_address): - self.osc_protocol._process_command_queue() - return True - - def service_actions(self): - if self.osc_protocol.healthcheck.active: - self.osc_protocol._run_healthcheck() - - -class ThreadedOscHandler(socketserver.BaseRequestHandler): - def handle(self): - data = self.request[0] - for callback, message in self.server.osc_protocol._validate_receive(data): - callback.procedure( - message, *(callback.args or ()), **(callback.kwargs or {}) - ) - - -class ThreadedOscProtocol(OscProtocol): - ### INITIALIZER ### - - def __init__(self): - OscProtocol.__init__(self) - self.command_queue = queue.Queue() - self.lock = threading.RLock() - self.osc_server = None - self.osc_server_thread = None - - ### PRIVATE METHODS ### - - def _disconnect(self) -> None: - with self.lock: - if not self.is_running: - osc_protocol_logger.info( - f"{self.ip_address}:{self.port} already disconnected!" - ) - return - self._teardown() - if not self.osc_server._BaseServer__shutdown_request: - self.osc_server.shutdown() - self.osc_server = None - self.osc_server_thread = None - - def _process_command_queue(self): - while self.command_queue.qsize(): - try: - action, callback = self.command_queue.get() - except queue.Empty: - continue - if action == "add": - self._add_callback(callback) - elif action == "remove": - self._remove_callback(callback) - - def _run_healthcheck(self): - if self.healthcheck is None: - return - now = time.time() - if now < self.healthcheck_deadline: - return - if self.attempts > 0: - remaining = self.healthcheck.max_attempts - self.attempts - osc_protocol_logger.info( - f"[{self.ip_address}:{self.port}] healthcheck failed, {remaining} attempts remaining" - ) - new_timeout = self.healthcheck.timeout * pow( - self.healthcheck.backoff_factor, self.attempts - ) - self.healthcheck_deadline = now + new_timeout - self.attempts += 1 - if self.attempts <= self.healthcheck.max_attempts: - osc_protocol_logger.info( - f"[{self.ip_address}:{self.port}] healthcheck: checking..." - ) - self.send(OscMessage(*self.healthcheck.request_pattern)) - return - osc_protocol_logger.info( - f"[{self.ip_address}:{self.port}] healthcheck: failure limit exceeded" - ) - self.osc_server._BaseServer__shutdown_request = True - self.disconnect() - self.healthcheck.callback() - - def _server_factory(self, ip_address, port): - server = ThreadedOscServer( - (self.ip_address, self.port), ThreadedOscHandler, bind_and_activate=False - ) - server.osc_protocol = self - return server - - ### PUBLIC METHODS ### - - def activate_healthcheck(self) -> None: - if not self.healthcheck: - return - elif self.healthcheck.active: - return - self.healthcheck.active = True - - def connect( - self, ip_address: str, port: int, *, healthcheck: Optional[HealthCheck] = None - ): - osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] connecting...") - if self.is_running: - osc_protocol_logger.info( - f"[{self.ip_address}:{self.port}] already connected!" - ) - raise OscProtocolAlreadyConnected - self._setup(ip_address, port, healthcheck) - self.healthcheck_deadline = time.time() - self.osc_server = self._server_factory(ip_address, port) - self.osc_server_thread = threading.Thread(target=self.osc_server.serve_forever) - self.osc_server_thread.daemon = True - self.osc_server_thread.start() - self.is_running = True - osc_protocol_logger.info(f"[{self.ip_address}:{self.port}] ...connected") - - def register( - self, - pattern: Sequence[Union[str, float]], - procedure: Callable[[OscMessage], None], - *, - failure_pattern: Optional[Sequence[Union[str, float]]] = None, - once: bool = False, - args: Optional[Tuple] = None, - kwargs: Optional[Dict] = None, - ) -> OscCallback: - """ - Register a callback. - """ - callback = self._validate_callback( - pattern, - procedure, - failure_pattern=failure_pattern, - once=once, - args=args, - kwargs=kwargs, - ) - # Command queue prevents lock contention. - self.command_queue.put(("add", callback)) - return callback - - def send(self, message) -> None: - datagram = self._validate_send(message) - try: - self.osc_server.socket.sendto(datagram, (self.ip_address, self.port)) - except OSError: - # print(message) - raise - - def unregister(self, callback: OscCallback) -> None: - """ - Unregister a callback. - """ - # Command queue prevents lock contention. - self.command_queue.put(("remove", callback)) - - -class CaptureEntry(NamedTuple): - timestamp: float - label: str - message: Union[OscMessage, OscBundle] - - -class Capture: - ### INITIALIZER ### - - def __init__(self, osc_protocol): - self.osc_protocol = osc_protocol - self.messages = [] - - ### SPECIAL METHODS ### - - def __enter__(self): - self.osc_protocol.captures.add(self) - self.messages[:] = [] - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.osc_protocol.captures.remove(self) - - def __iter__(self): - return iter(self.messages) - - def __len__(self): - return len(self.messages) - - ### PUBLIC METHODS ### - - def filtered( - self, sent=True, received=True, status=True - ) -> List[Union[OscBundle, OscMessage]]: - messages = [] - for _, label, message in self.messages: - if label == "R" and not received: - continue - if label == "S" and not sent: - continue - if ( - isinstance(message, OscMessage) - and message.address in ("/status", "/status.reply") - and not status - ): - continue - messages.append(message) - return messages - - ### PUBLIC PROPERTIES ### - - @property - def received_messages(self): - return [ - (timestamp, osc_message) - for timestamp, label, osc_message in self.messages - if label == "R" - ] - - @property - def sent_messages(self): - return [ - (timestamp, osc_message) - for timestamp, label, osc_message in self.messages - if label == "S" - ] - - -def find_free_port(): - with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) as s: - s.bind(("", 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return s.getsockname()[1] - - -def format_datagram(datagram): - result = [] - result.append("size {}".format(len(datagram))) - index = 0 - while index < len(datagram): - chunk = datagram[index : index + 16] - line = "{: >4} ".format(index) - hex_blocks = [] - ascii_block = "" - for chunk in group_by_count(chunk, 4): - hex_block = [] - for byte in chunk: - char = int(byte) - if 31 < char < 127: - char = chr(char) - else: - char = "." - ascii_block += char - hexed = hex(byte)[2:].zfill(2) - hex_block.append(hexed) - hex_block = " ".join(hex_block) - hex_blocks.append(hex_block) - hex_blocks = " ".join(hex_blocks) - ascii_block = "|{}|".format(ascii_block) - hex_blocks = "{: <53}".format(hex_blocks) - line += hex_blocks - line += ascii_block - result.append(line) - index += 16 - result = "\n".join(result) - return result - - -__all__ = [ - "AsyncOscProtocol", - "Capture", - "CaptureEntry", - "HealthCheck", - "OscBundle", - "OscCallback", - "OscMessage", - "OscProtocol", - "ThreadedOscProtocol", - "find_free_port", -] diff --git a/supriya/osc/__init__.py b/supriya/osc/__init__.py new file mode 100644 index 000000000..7e773c0bb --- /dev/null +++ b/supriya/osc/__init__.py @@ -0,0 +1,32 @@ +""" +Tools for sending, receiving and handling OSC messages. +""" + +from .asynchronous import AsyncOscProtocol +from .messages import OscBundle, OscMessage +from .protocols import ( + Capture, + CaptureEntry, + HealthCheck, + OscCallback, + OscProtocol, + OscProtocolAlreadyConnected, + OscProtocolOffline, + find_free_port, +) +from .threaded import ThreadedOscProtocol + +__all__ = [ + "AsyncOscProtocol", + "Capture", + "CaptureEntry", + "HealthCheck", + "OscBundle", + "OscCallback", + "OscMessage", + "OscProtocol", + "OscProtocolAlreadyConnected", + "OscProtocolOffline", + "ThreadedOscProtocol", + "find_free_port", +] diff --git a/supriya/osc/asynchronous.py b/supriya/osc/asynchronous.py new file mode 100644 index 000000000..308405439 --- /dev/null +++ b/supriya/osc/asynchronous.py @@ -0,0 +1,193 @@ +import asyncio +from collections.abc import Sequence as SequenceABC +from typing import Awaitable, Callable, Dict, Optional, Sequence, Set, Tuple, Union + +from .messages import OscBundle, OscMessage +from .protocols import ( + BootStatus, + HealthCheck, + OscCallback, + OscProtocol, + OscProtocolAlreadyConnected, + osc_out_logger, + osc_protocol_logger, +) + + +class AsyncOscProtocol(asyncio.DatagramProtocol, OscProtocol): + ### INITIALIZER ### + + def __init__( + self, + *, + name: Optional[str] = None, + on_connect_callback: Optional[Callable] = None, + on_disconnect_callback: Optional[Callable] = None, + on_panic_callback: Optional[Callable] = None, + ) -> None: + asyncio.DatagramProtocol.__init__(self) + OscProtocol.__init__( + self, + boot_future=asyncio.Future(), + exit_future=asyncio.Future(), + name=name, + on_connect_callback=on_connect_callback, + on_disconnect_callback=on_disconnect_callback, + on_panic_callback=on_panic_callback, + ) + self.background_tasks: Set[asyncio.Task] = set() + self.healthcheck_task: Optional[asyncio.Task] = None + + ### PRIVATE METHODS ### + + async def _disconnect(self, panicked: bool = False) -> None: + super()._disconnect(panicked=panicked) + self.transport.close() + if self.healthcheck_task: + self.healthcheck_task.cancel() + await self._on_disconnect(panicked=panicked) + + async def _on_connect(self) -> None: + super()._on_connect() + if self.on_connect_callback: + if asyncio.iscoroutine(result := self.on_connect_callback()): + await result + + async def _on_disconnect(self, panicked: bool = False) -> None: + super()._on_disconnect(panicked=panicked) + if panicked and self.on_panic_callback: + if asyncio.iscoroutine(result := self.on_panic_callback()): + await result + elif not panicked and self.on_disconnect_callback: + if asyncio.iscoroutine(result := self.on_disconnect_callback()): + await result + + async def _on_healthcheck_passed(self, message: OscMessage) -> None: + super()._on_healthcheck_passed(message) + if self.status == BootStatus.BOOTING: + await self._on_connect() + + async def _run_healthcheck(self): + if self.healthcheck is None: + return + while self.status in (BootStatus.BOOTING, BootStatus.ONLINE): + if self.attempts >= self.healthcheck.max_attempts: + await self._disconnect(panicked=True) + return + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "healthcheck: checking ..." + ) + self.send(OscMessage(*self.healthcheck.request_pattern)) + sleep_time = self.healthcheck.timeout * pow( + self.healthcheck.backoff_factor, self.attempts + ) + self.attempts += 1 + await asyncio.sleep(sleep_time) + + ### OVERRIDES ### + + def connection_made(self, transport): + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "connection made!" + ) + self.transport = transport + + def connection_lost(self, exc): + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "connection lost!" + ) + + def datagram_received(self, data, addr): + loop = asyncio.get_running_loop() + for callback, message in self._validate_receive(data): + if asyncio.iscoroutine( + result := callback.procedure( + message, *(callback.args or ()), **(callback.kwargs or {}) + ) + ): + self.background_tasks.add(task := loop.create_task(result)) + task.add_done_callback(self.background_tasks.discard) + + def error_received(self, exc): + osc_out_logger.warning( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + f"errored: {exc}" + ) + + ### PUBLIC METHODS ### + + def activate_healthcheck(self) -> None: + if self._activate_healthcheck(): + self.healthcheck_task = asyncio.get_running_loop().create_task( + self._run_healthcheck() + ) + + async def connect( + self, ip_address: str, port: int, *, healthcheck: Optional[HealthCheck] = None + ): + if self.status != BootStatus.OFFLINE: + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "already connected!" + ) + raise OscProtocolAlreadyConnected + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "connecting ..." + ) + self._setup(ip_address, port, healthcheck) + loop = asyncio.get_running_loop() + self.boot_future = loop.create_future() + self.exit_future = loop.create_future() + _, protocol = await loop.create_datagram_endpoint( + lambda: self, remote_addr=(ip_address, port) + ) + if self.healthcheck and self.healthcheck.active: + self.healthcheck_task = asyncio.get_running_loop().create_task( + self._run_healthcheck() + ) + elif not self.healthcheck: + await self._on_connect() + + async def disconnect(self) -> None: + if self.status != BootStatus.ONLINE: + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "already disconnected!" + ) + return + await self._disconnect() + + def register( + self, + pattern: Sequence[Union[str, float]], + procedure: Callable[[OscMessage], Optional[Awaitable[None]]], + *, + failure_pattern: Optional[Sequence[Union[str, float]]] = None, + once: bool = False, + args: Optional[Tuple] = None, + kwargs: Optional[Dict] = None, + ) -> OscCallback: + """ + Register a callback. + """ + self._add_callback( + callback := self._register( + pattern, + procedure, + failure_pattern=failure_pattern, + once=once, + args=args, + kwargs=kwargs, + ) + ) + return callback + + def send(self, message: Union[OscBundle, OscMessage, SequenceABC, str]) -> None: + self.transport.sendto(self._send(message)) + + def unregister(self, callback: OscCallback): + self._remove_callback(callback) diff --git a/supriya/osc/messages.py b/supriya/osc/messages.py new file mode 100644 index 000000000..878431abc --- /dev/null +++ b/supriya/osc/messages.py @@ -0,0 +1,505 @@ +import collections +import datetime +import enum +import struct +import time +from collections.abc import Sequence as SequenceABC +from typing import List + +from uqbar.objects import get_repr + +from ..utils import group_by_count + +BUNDLE_PREFIX = b"#bundle\x00" +IMMEDIATELY = struct.pack(">Q", 1) +NTP_TIMESTAMP_TO_SECONDS = 1.0 / 2.0**32.0 +SECONDS_TO_NTP_TIMESTAMP = 2.0**32.0 +SYSTEM_EPOCH = datetime.date(*time.gmtime(0)[0:3]) +NTP_EPOCH = datetime.date(1900, 1, 1) +NTP_DELTA = (SYSTEM_EPOCH - NTP_EPOCH).days * 24 * 3600 + + +def format_datagram(datagram: bytes) -> str: + result: List[str] = ["size {}".format(len(datagram))] + index = 0 + while index < len(datagram): + hex_blocks = [] + ascii_block = "" + for chunk in group_by_count(datagram[index : index + 16], 4): + hex_block = [] + for byte in chunk: + if 31 < int(byte) < 127: + char = chr(int(byte)) + else: + char = "." + ascii_block += char + hexed = hex(byte)[2:].zfill(2) + hex_block.append(hexed) + hex_blocks.append(" ".join(hex_block)) + line = "{: >4} ".format(index) + line += "{: <53}".format(" ".join(hex_blocks)) + line += "|{}|".format(ascii_block) + result.append(line) + index += 16 + return "\n".join(result) + + +class OscMessage: + """ + An OSC message. + + .. container:: example + + :: + + >>> from supriya.osc import OscMessage + >>> osc_message = OscMessage("/g_new", 0, 0) + >>> osc_message + OscMessage('/g_new', 0, 0) + + :: + + >>> datagram = osc_message.to_datagram() + >>> OscMessage.from_datagram(datagram) + OscMessage('/g_new', 0, 0) + + :: + + >>> print(osc_message) + size 20 + 0 2f 67 5f 6e 65 77 00 00 2c 69 69 00 00 00 00 00 |/g_new..,ii.....| + 16 00 00 00 00 |....| + + .. container:: example + + :: + + >>> osc_message = OscMessage("/foo", True, [None, [3.25]], OscMessage("/bar")) + >>> osc_message + OscMessage('/foo', True, [None, [3.25]], OscMessage('/bar')) + + :: + + >>> datagram = osc_message.to_datagram() + >>> OscMessage.from_datagram(datagram) + OscMessage('/foo', True, [None, [3.25]], OscMessage('/bar')) + + :: + + >>> print(osc_message) + size 40 + 0 2f 66 6f 6f 00 00 00 00 2c 54 5b 4e 5b 66 5d 5d |/foo....,T[N[f]]| + 16 62 00 00 00 40 50 00 00 00 00 00 0c 2f 62 61 72 |b...@P....../bar| + 32 00 00 00 00 2c 00 00 00 |....,...| + + .. container:: example + + :: + + >>> osc_message = supriya.osc.OscMessage( + ... "/foo", + ... 1, + ... 2.5, + ... supriya.osc.OscBundle( + ... contents=( + ... supriya.osc.OscMessage("/bar", "baz", 3.0), + ... supriya.osc.OscMessage("/ffff", False, True, None), + ... ) + ... ), + ... ["a", "b", ["c", "d"]], + ... ) + >>> osc_message + OscMessage('/foo', 1, 2.5, OscBundle( + contents=( + OscMessage('/bar', 'baz', 3.0), + OscMessage('/ffff', False, True, None), + ), + ), ['a', 'b', ['c', 'd']]) + + :: + + >>> datagram = osc_message.to_datagram() + >>> OscMessage.from_datagram(datagram) + OscMessage('/foo', 1, 2.5, OscBundle( + contents=( + OscMessage('/bar', 'baz', 3.0), + OscMessage('/ffff', False, True, None), + ), + ), ['a', 'b', ['c', 'd']]) + + :: + + >>> print(osc_message) + size 112 + 0 2f 66 6f 6f 00 00 00 00 2c 69 66 62 5b 73 73 5b |/foo....,ifb[ss[| + 16 73 73 5d 5d 00 00 00 00 00 00 00 01 40 20 00 00 |ss]]........@ ..| + 32 00 00 00 3c 23 62 75 6e 64 6c 65 00 00 00 00 00 |...<#bundle.....| + 48 00 00 00 01 00 00 00 14 2f 62 61 72 00 00 00 00 |......../bar....| + 64 2c 73 66 00 62 61 7a 00 40 40 00 00 00 00 00 10 |,sf.baz.@@......| + 80 2f 66 66 66 66 00 00 00 2c 46 54 4e 00 00 00 00 |/ffff...,FTN....| + 96 61 00 00 00 62 00 00 00 63 00 00 00 64 00 00 00 |a...b...c...d...| + """ + + ### INITIALIZER ### + + def __init__(self, address, *contents) -> None: + if isinstance(address, enum.Enum): + address = address.value + if not isinstance(address, (str, int)): + raise ValueError(f"address must be int or str, got {address}") + self.address = address + self.contents = tuple(contents) + + ### SPECIAL METHODS ### + + def __eq__(self, other) -> bool: + if type(self) is not type(other): + return False + if self.address != other.address: + return False + if self.contents != other.contents: + return False + return True + + def __repr__(self) -> str: + return "{}({})".format( + type(self).__name__, + ", ".join(repr(_) for _ in [self.address, *self.contents]), + ) + + def __str__(self) -> str: + return format_datagram(bytearray(self.to_datagram())) + + ### PRIVATE METHODS ### + + @staticmethod + def _decode_blob(data): + actual_length, remainder = struct.unpack(">I", data[:4])[0], data[4:] + padded_length = actual_length + if actual_length % 4 != 0: + padded_length = (actual_length // 4 + 1) * 4 + return remainder[:padded_length][:actual_length], remainder[padded_length:] + + @staticmethod + def _decode_string(data): + actual_length = data.index(b"\x00") + padded_length = (actual_length // 4 + 1) * 4 + return str(data[:actual_length], "ascii"), data[padded_length:] + + @staticmethod + def _encode_string(value): + result = bytes(value + "\x00", "ascii") + if len(result) % 4 != 0: + width = (len(result) // 4 + 1) * 4 + result = result.ljust(width, b"\x00") + return result + + @staticmethod + def _encode_blob(value): + result = bytes(struct.pack(">I", len(value)) + value) + if len(result) % 4 != 0: + width = (len(result) // 4 + 1) * 4 + result = result.ljust(width, b"\x00") + return result + + @classmethod + def _encode_value(cls, value): + if hasattr(value, "to_datagram"): + value = bytearray(value.to_datagram()) + elif isinstance(value, enum.Enum): + value = value.value + type_tags, encoded_value = "", b"" + if isinstance(value, (bytearray, bytes)): + type_tags += "b" + encoded_value = cls._encode_blob(value) + elif isinstance(value, str): + type_tags += "s" + encoded_value = cls._encode_string(value) + elif isinstance(value, bool): + type_tags += "T" if value else "F" + elif isinstance(value, float): + type_tags += "f" + encoded_value += struct.pack(">f", value) + elif isinstance(value, int): + type_tags += "i" + encoded_value += struct.pack(">i", value) + elif value is None: + type_tags += "N" + elif isinstance(value, SequenceABC): + type_tags += "[" + for sub_value in value: + sub_type_tags, sub_encoded_value = cls._encode_value(sub_value) + type_tags += sub_type_tags + encoded_value += sub_encoded_value + type_tags += "]" + else: + message = "Cannot encode {!r}".format(value) + raise TypeError(message) + return type_tags, encoded_value + + ### PUBLIC METHODS ### + + def to_datagram(self) -> bytes: + # address can be a string or (in SuperCollider) an int + if isinstance(self.address, str): + encoded_address = self._encode_string(self.address) + else: + encoded_address = struct.pack(">i", self.address) + encoded_type_tags = "," + encoded_contents = b"" + for value in self.contents or (): + type_tags, encoded_value = self._encode_value(value) + encoded_type_tags += type_tags + encoded_contents += encoded_value + return ( + encoded_address + self._encode_string(encoded_type_tags) + encoded_contents + ) + + @classmethod + def from_datagram(cls, datagram): + remainder = datagram + address, remainder = cls._decode_string(remainder) + type_tags, remainder = cls._decode_string(remainder) + contents = [] + array_stack = [contents] + for type_tag in type_tags[1:]: + if type_tag == "i": + value, remainder = struct.unpack(">i", remainder[:4])[0], remainder[4:] + array_stack[-1].append(value) + elif type_tag == "f": + value, remainder = struct.unpack(">f", remainder[:4])[0], remainder[4:] + array_stack[-1].append(value) + elif type_tag == "d": + value, remainder = struct.unpack(">d", remainder[:8])[0], remainder[8:] + array_stack[-1].append(value) + elif type_tag == "s": + value, remainder = cls._decode_string(remainder) + array_stack[-1].append(value) + elif type_tag == "b": + value, remainder = cls._decode_blob(remainder) + for class_ in (OscBundle, OscMessage): + try: + value = class_.from_datagram(value) + break + except Exception: + pass + array_stack[-1].append(value) + elif type_tag == "T": + array_stack[-1].append(True) + elif type_tag == "F": + array_stack[-1].append(False) + elif type_tag == "N": + array_stack[-1].append(None) + elif type_tag == "[": + array = [] + array_stack[-1].append(array) + array_stack.append(array) + elif type_tag == "]": + array_stack.pop() + else: + raise RuntimeError(f"Unable to parse type {type_tag!r}") + return cls(address, *contents) + + def to_list(self): + result = [self.address] + for x in self.contents: + if hasattr(x, "to_list"): + result.append(x.to_list()) + else: + result.append(x) + return result + + +class OscBundle: + """ + An OSC bundle. + + :: + + >>> import supriya.osc + >>> message_one = supriya.osc.OscMessage("/one", 1) + >>> message_two = supriya.osc.OscMessage("/two", 2) + >>> message_three = supriya.osc.OscMessage("/three", 3) + + :: + + >>> inner_bundle = supriya.osc.OscBundle( + ... timestamp=1401557034.5, + ... contents=(message_one, message_two), + ... ) + >>> inner_bundle + OscBundle( + contents=( + OscMessage('/one', 1), + OscMessage('/two', 2), + ), + timestamp=1401557034.5, + ) + + :: + + >>> print(inner_bundle) + size 56 + 0 23 62 75 6e 64 6c 65 00 d7 34 8e aa 80 00 00 00 |#bundle..4......| + 16 00 00 00 10 2f 6f 6e 65 00 00 00 00 2c 69 00 00 |..../one....,i..| + 32 00 00 00 01 00 00 00 10 2f 74 77 6f 00 00 00 00 |......../two....| + 48 2c 69 00 00 00 00 00 02 |,i......| + + :: + + >>> outer_bundle = supriya.osc.OscBundle( + ... contents=(inner_bundle, message_three), + ... ) + >>> outer_bundle + OscBundle( + contents=( + OscBundle( + contents=( + OscMessage('/one', 1), + OscMessage('/two', 2), + ), + timestamp=1401557034.5, + ), + OscMessage('/three', 3), + ), + ) + + :: + + >>> print(outer_bundle) + size 96 + 0 23 62 75 6e 64 6c 65 00 00 00 00 00 00 00 00 01 |#bundle.........| + 16 00 00 00 38 23 62 75 6e 64 6c 65 00 d7 34 8e aa |...8#bundle..4..| + 32 80 00 00 00 00 00 00 10 2f 6f 6e 65 00 00 00 00 |......../one....| + 48 2c 69 00 00 00 00 00 01 00 00 00 10 2f 74 77 6f |,i........../two| + 64 00 00 00 00 2c 69 00 00 00 00 00 02 00 00 00 10 |....,i..........| + 80 2f 74 68 72 65 65 00 00 2c 69 00 00 00 00 00 03 |/three..,i......| + + :: + + >>> datagram = outer_bundle.to_datagram() + + :: + + >>> decoded_bundle = supriya.osc.OscBundle.from_datagram(datagram) + >>> decoded_bundle + OscBundle( + contents=( + OscBundle( + contents=( + OscMessage('/one', 1), + OscMessage('/two', 2), + ), + timestamp=1401557034.5, + ), + OscMessage('/three', 3), + ), + ) + + :: + + >>> decoded_bundle == outer_bundle + True + """ + + ### INITIALIZER ### + + def __init__(self, timestamp=None, contents=None) -> None: + prototype = (OscMessage, type(self)) + self.timestamp = timestamp + contents = contents or () + for x in contents or (): + if not isinstance(x, prototype): + raise ValueError(contents) + self.contents = tuple(contents) + + ### SPECIAL METHODS ### + + def __eq__(self, other) -> bool: + if type(self) is not type(other): + return False + if self.timestamp != other.timestamp: + return False + if self.contents != other.contents: + return False + return True + + def __repr__(self) -> str: + return get_repr(self) + + def __str__(self) -> str: + return format_datagram(bytearray(self.to_datagram())) + + ### PRIVATE METHODS ### + + @staticmethod + def _decode_date(data): + data, remainder = data[:8], data[8:] + if data == IMMEDIATELY: + return None, remainder + date = (struct.unpack(">Q", data)[0] / SECONDS_TO_NTP_TIMESTAMP) - NTP_DELTA + return date, remainder + + @staticmethod + def _encode_date(seconds, realtime=True): + if seconds is None: + return IMMEDIATELY + if realtime: + seconds = seconds + NTP_DELTA + if seconds >= 4294967296: # 2**32 + seconds = seconds % 4294967296 + return struct.pack(">Q", int(seconds * SECONDS_TO_NTP_TIMESTAMP)) + + ### PUBLIC METHODS ### + + @classmethod + def from_datagram(cls, datagram): + if not datagram.startswith(BUNDLE_PREFIX): + raise ValueError("datagram is not a bundle") + remainder = datagram[8:] + timestamp, remainder = cls._decode_date(remainder) + contents = [] + while len(remainder): + length, remainder = struct.unpack(">i", remainder[:4])[0], remainder[4:] + if remainder.startswith(BUNDLE_PREFIX): + item = cls.from_datagram(remainder[:length]) + else: + item = OscMessage.from_datagram(remainder[:length]) + contents.append(item) + remainder = remainder[length:] + osc_bundle = cls(timestamp=timestamp, contents=tuple(contents)) + return osc_bundle + + @classmethod + def partition(cls, messages, timestamp=None): + bundles = [] + contents = [] + message = collections.deque(messages) + remaining = maximum = 8192 - len(BUNDLE_PREFIX) - 4 + while messages: + message = messages.popleft() + datagram = message.to_datagram() + remaining -= len(datagram) + 4 + if remaining > 0: + contents.append(message) + else: + bundles.append(cls(timestamp=timestamp, contents=contents)) + contents = [message] + remaining = maximum + if contents: + bundles.append(cls(timestamp=timestamp, contents=contents)) + return bundles + + def to_datagram(self, realtime=True) -> bytes: + datagram = BUNDLE_PREFIX + datagram += self._encode_date(self.timestamp, realtime=realtime) + for content in self.contents: + content_datagram = content.to_datagram() + datagram += struct.pack(">i", len(content_datagram)) + datagram += content_datagram + return datagram + + def to_list(self): + result = [self.timestamp] + result.append([x.to_list() for x in self.contents]) + return result diff --git a/supriya/osc/protocols.py b/supriya/osc/protocols.py new file mode 100644 index 000000000..169ddc376 --- /dev/null +++ b/supriya/osc/protocols.py @@ -0,0 +1,383 @@ +import contextlib +import dataclasses +import enum +import logging +import socket +import time +from collections.abc import Sequence as SequenceABC +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +from ..typing import FutureLike +from .messages import OscBundle, OscMessage + +osc_protocol_logger = logging.getLogger(__name__) +osc_in_logger = logging.getLogger("supriya.osc.in") +osc_out_logger = logging.getLogger("supriya.osc.out") +udp_in_logger = logging.getLogger("supriya.udp.in") +udp_out_logger = logging.getLogger("supriya.udp.out") + + +def find_free_port(): + with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +class OscProtocolOffline(Exception): + pass + + +class OscProtocolAlreadyConnected(Exception): + pass + + +class OscCallback(NamedTuple): + pattern: Tuple[Union[str, int, float], ...] + procedure: Callable + failure_pattern: Optional[Tuple[Union[str, int, float], ...]] = None + once: bool = False + args: Optional[Tuple] = None + kwargs: Optional[Dict] = None + + +@dataclasses.dataclass +class HealthCheck: + request_pattern: List[str] + response_pattern: List[str] + active: bool = True + timeout: float = 1.0 + backoff_factor: float = 1.5 + max_attempts: int = 5 + + +class BootStatus(enum.IntEnum): + OFFLINE = 0 + BOOTING = 1 + ONLINE = 2 + QUITTING = 3 + + +class CaptureEntry(NamedTuple): + timestamp: float + label: str + message: Union[OscMessage, OscBundle] + + +class Capture: + ### INITIALIZER ### + + def __init__(self, osc_protocol): + self.osc_protocol = osc_protocol + self.messages = [] + + ### SPECIAL METHODS ### + + def __enter__(self): + self.osc_protocol.captures.add(self) + self.messages[:] = [] + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.osc_protocol.captures.remove(self) + + def __iter__(self): + return iter(self.messages) + + def __len__(self): + return len(self.messages) + + ### PUBLIC METHODS ### + + def filtered( + self, sent=True, received=True, status=True + ) -> List[Union[OscBundle, OscMessage]]: + messages = [] + for _, label, message in self.messages: + if label == "R" and not received: + continue + if label == "S" and not sent: + continue + if ( + isinstance(message, OscMessage) + and message.address in ("/status", "/status.reply") + and not status + ): + continue + messages.append(message) + return messages + + ### PUBLIC PROPERTIES ### + + @property + def received_messages(self): + return [ + (timestamp, osc_message) + for timestamp, label, osc_message in self.messages + if label == "R" + ] + + @property + def sent_messages(self): + return [ + (timestamp, osc_message) + for timestamp, label, osc_message in self.messages + if label == "S" + ] + + +class OscProtocol: + ### INITIALIZER ### + + def __init__( + self, + *, + boot_future: FutureLike[bool], + exit_future: FutureLike[bool], + name: Optional[str] = None, + on_connect_callback: Optional[Callable] = None, + on_disconnect_callback: Optional[Callable] = None, + on_panic_callback: Optional[Callable] = None, + ) -> None: + self.callbacks: Dict[Any, Any] = {} + self.captures: Set[Capture] = set() + self.boot_future = boot_future + self.exit_future = exit_future + self.healthcheck: Optional[HealthCheck] = None + self.healthcheck_osc_callback: Optional[OscCallback] = None + self.attempts = 0 + self.ip_address = "127.0.0.1" + self.name = name + self.port = 57551 + self.on_connect_callback = on_connect_callback + self.on_disconnect_callback = on_disconnect_callback + self.on_panic_callback = on_panic_callback + self.status = BootStatus.OFFLINE + + ### PRIVATE METHODS ### + + def _activate_healthcheck(self) -> bool: + if not self.healthcheck: + return False + elif self.healthcheck.active: + return False + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "activating healthcheck..." + ) + return True + + def _add_callback(self, callback: OscCallback) -> None: + patterns = [callback.pattern] + if callback.failure_pattern: + patterns.append(callback.failure_pattern) + for pattern in patterns: + callback_map = self.callbacks + for item in pattern: + callbacks, callback_map = callback_map.setdefault(item, ([], {})) + callbacks.append(callback) + + def _disconnect(self, panicked: bool = False) -> Optional[Awaitable[None]]: + if panicked: + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "panicking ..." + ) + else: + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "disconnecting ..." + ) + self.status = BootStatus.QUITTING + if self.healthcheck_osc_callback is not None: + self.unregister(self.healthcheck_osc_callback) + return None + + def _match_callbacks(self, message) -> List[OscCallback]: + items = (message.address,) + message.contents + matching_callbacks = [] + callback_map = self.callbacks + for item in items: + if item not in callback_map: + break + callbacks, callback_map = callback_map[item] + matching_callbacks.extend(callbacks) + for callback in matching_callbacks: + if callback.once: + self.unregister(callback) + return matching_callbacks + + def _on_connect(self) -> Optional[Awaitable[None]]: + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "... connected!" + ) + self.status = BootStatus.ONLINE + self.boot_future.set_result(True) + return None + + def _on_disconnect(self, panicked: bool = False) -> Optional[Awaitable[None]]: + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "... disconnected!" + ) + self.status = BootStatus.OFFLINE + if not self.boot_future.done(): + self.boot_future.set_result(False) + if not self.exit_future.done(): + self.exit_future.set_result(not panicked) + return None + + def _on_healthcheck_passed(self, message: OscMessage) -> Optional[Awaitable[None]]: + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "healthcheck: passed" + ) + self.attempts = 0 + return None + + def _remove_callback(self, callback: OscCallback) -> None: + def delete(pattern, original_callback_map): + key = pattern.pop(0) + if key not in original_callback_map: + return + callbacks, callback_map = original_callback_map[key] + if pattern: + delete(pattern, callback_map) + if callback in callbacks: + callbacks.remove(callback) + if not callbacks and not callback_map: + original_callback_map.pop(key) + + patterns = [callback.pattern] + if callback.failure_pattern: + patterns.append(callback.failure_pattern) + for pattern in patterns: + delete(list(pattern), self.callbacks) + + def _register( + self, + pattern, + procedure, + *, + failure_pattern=None, + once: bool = False, + args: Optional[Tuple] = None, + kwargs: Optional[Dict] = None, + ) -> OscCallback: + if isinstance(pattern, (str, int, float)): + pattern = [pattern] + if isinstance(failure_pattern, (str, int, float)): + failure_pattern = [failure_pattern] + if not callable(procedure): + raise ValueError(procedure) + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + f"registering pattern: {pattern!r}" + ) + return OscCallback( + pattern=tuple(pattern), + failure_pattern=failure_pattern, + procedure=procedure, + once=bool(once), + args=args, + kwargs=kwargs, + ) + + def _send(self, message): + if self.status not in (BootStatus.BOOTING, BootStatus.ONLINE): + raise OscProtocolOffline + if not isinstance(message, (str, SequenceABC, OscBundle, OscMessage)): + raise ValueError(message) + if isinstance(message, str): + message = OscMessage(message) + elif isinstance(message, SequenceABC): + message = OscMessage(*message) + osc_out_logger.debug( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + f"{message!r}" + ) + for capture in self.captures: + capture.messages.append( + CaptureEntry(timestamp=time.time(), label="S", message=message) + ) + datagram = message.to_datagram() + udp_out_logger.debug( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + f"{datagram}" + ) + return datagram + + def _setup( + self, ip_address: str, port: int, healthcheck: Optional[HealthCheck] + ) -> None: + self.status = BootStatus.BOOTING + self.ip_address = ip_address + self.port = port + self.healthcheck = healthcheck + if self.healthcheck: + self.healthcheck_osc_callback = self.register( + pattern=self.healthcheck.response_pattern, + procedure=self._on_healthcheck_passed, + ) + + def _validate_receive(self, datagram): + udp_in_logger.debug( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + f"{datagram}" + ) + try: + message = OscMessage.from_datagram(datagram) + except Exception: + raise + osc_in_logger.debug( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + f"{message!r}" + ) + for capture in self.captures: + capture.messages.append( + CaptureEntry(timestamp=time.time(), label="R", message=message) + ) + for callback in self._match_callbacks(message): + yield callback, message + + ### PUBLIC METHODS ### + + def activate_healthcheck(self) -> None: + raise NotImplementedError + + def capture(self) -> "Capture": + return Capture(self) + + def disconnect(self) -> Optional[Awaitable[None]]: + raise NotImplementedError + + def register( + self, + pattern: Sequence[Union[str, float]], + procedure: Callable[[OscMessage], Optional[Awaitable[None]]], + *, + failure_pattern: Optional[Sequence[Union[str, float]]] = None, + once: bool = False, + args: Optional[Tuple] = None, + kwargs: Optional[Dict] = None, + ) -> OscCallback: + raise NotImplementedError + + def send(self, message: Union[OscBundle, OscMessage, SequenceABC, str]) -> None: + raise NotImplementedError + + def unregister(self, callback: OscCallback) -> None: + raise NotImplementedError diff --git a/supriya/osc/threaded.py b/supriya/osc/threaded.py new file mode 100644 index 000000000..59860b1db --- /dev/null +++ b/supriya/osc/threaded.py @@ -0,0 +1,234 @@ +import concurrent.futures +import socketserver +import threading +import time +from collections.abc import Sequence as SequenceABC +from queue import Empty, Queue +from typing import ( + Awaitable, + Callable, + Dict, + Literal, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from .messages import OscBundle, OscMessage +from .protocols import ( + BootStatus, + HealthCheck, + OscCallback, + OscProtocol, + OscProtocolAlreadyConnected, + osc_protocol_logger, +) + + +class ThreadedOscProtocol(OscProtocol): + + class Server(socketserver.UDPServer): + osc_protocol: "ThreadedOscProtocol" + + def verify_request(self, request, client_address) -> bool: + self.osc_protocol._process_command_queue() + return True + + def service_actions(self) -> None: + if cast(HealthCheck, self.osc_protocol.healthcheck).active: + self.osc_protocol._run_healthcheck() + + class Handler(socketserver.BaseRequestHandler): + def handle(self) -> None: + data = self.request[0] + for callback, message in cast( + ThreadedOscProtocol.Server, self.server + ).osc_protocol._validate_receive(data): + callback.procedure( + message, *(callback.args or ()), **(callback.kwargs or {}) + ) + + ### INITIALIZER ### + + def __init__( + self, + *, + name: Optional[str] = None, + on_connect_callback: Optional[Callable] = None, + on_disconnect_callback: Optional[Callable] = None, + on_panic_callback: Optional[Callable] = None, + ): + OscProtocol.__init__( + self, + boot_future=concurrent.futures.Future(), + exit_future=concurrent.futures.Future(), + name=name, + on_connect_callback=on_connect_callback, + on_disconnect_callback=on_disconnect_callback, + on_panic_callback=on_panic_callback, + ) + self.command_queue: Queue[Tuple[Literal["add", "remove"], OscCallback]] = ( + Queue() + ) + self.lock = threading.RLock() + self.osc_server = self._server_factory(self.ip_address, self.port) + self.osc_server_thread = threading.Thread(target=self.osc_server.serve_forever) + + ### PRIVATE METHODS ### + + def _disconnect(self, panicked: bool = False) -> None: + super()._disconnect(panicked=panicked) + if not self.osc_server._BaseServer__shutdown_request: + # We set the shutdown request flag rather than call .shutdown() + # because this is often being called from _inside_ the server + # thread. + self.osc_server._BaseServer__shutdown_request = True + self._on_disconnect(panicked=panicked) + + def _on_connect(self) -> None: + super()._on_connect() + if self.on_connect_callback: + self.on_connect_callback() + + def _on_disconnect(self, panicked: bool = False) -> None: + super()._on_disconnect(panicked=panicked) + if panicked and self.on_panic_callback: + self.on_panic_callback() + elif not panicked and self.on_disconnect_callback: + self.on_disconnect_callback() + + def _on_healthcheck_passed(self, message: OscMessage) -> None: + super()._on_healthcheck_passed(message) + if self.status == BootStatus.BOOTING: + self._on_connect() + + def _process_command_queue(self): + while self.command_queue.qsize(): + try: + action, callback = self.command_queue.get() + except Empty: + continue + if action == "add": + self._add_callback(callback) + elif action == "remove": + self._remove_callback(callback) + + def _run_healthcheck(self): + if self.healthcheck is None: + return + now = time.time() + if now < self.healthcheck_deadline: + return + if self.attempts > 0: + remaining = self.healthcheck.max_attempts - self.attempts + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + f"healthcheck failed, {remaining} attempts remaining" + ) + new_timeout = self.healthcheck.timeout * pow( + self.healthcheck.backoff_factor, self.attempts + ) + self.healthcheck_deadline = now + new_timeout + self.attempts += 1 + if self.attempts <= self.healthcheck.max_attempts: + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "healthcheck: checking ..." + ) + self.send(OscMessage(*self.healthcheck.request_pattern)) + return + self._disconnect(panicked=True) + + def _server_factory(self, ip_address, port): + server = self.Server( + (self.ip_address, self.port), self.Handler, bind_and_activate=False + ) + server.osc_protocol = self + return server + + ### PUBLIC METHODS ### + + def activate_healthcheck(self) -> None: + if self._activate_healthcheck(): + cast(HealthCheck, self.healthcheck).active = True + + def connect( + self, ip_address: str, port: int, *, healthcheck: Optional[HealthCheck] = None + ): + if self.status != BootStatus.OFFLINE: + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "already connected!" + ) + raise OscProtocolAlreadyConnected + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "connecting ..." + ) + self._setup(ip_address, port, healthcheck) + self.healthcheck_deadline = time.time() + self.boot_future = concurrent.futures.Future() + self.exit_future = concurrent.futures.Future() + self.osc_server = self._server_factory(ip_address, port) + self.osc_server_thread = threading.Thread(target=self.osc_server.serve_forever) + self.osc_server_thread.daemon = True + self.osc_server_thread.start() + if not self.healthcheck: + self._on_connect() + + def disconnect(self) -> None: + if self.status != BootStatus.ONLINE: + osc_protocol_logger.info( + f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] " + "already disconnected!" + ) + return + self._disconnect() + + def register( + self, + pattern: Sequence[Union[str, float]], + procedure: Callable[[OscMessage], Optional[Awaitable[None]]], + *, + failure_pattern: Optional[Sequence[Union[str, float]]] = None, + once: bool = False, + args: Optional[Tuple] = None, + kwargs: Optional[Dict] = None, + ) -> OscCallback: + """ + Register a callback. + """ + # Command queue prevents lock contention. + self.command_queue.put( + ( + "add", + callback := self._register( + pattern, + procedure, + failure_pattern=failure_pattern, + once=once, + args=args, + kwargs=kwargs, + ), + ) + ) + return callback + + def send(self, message: Union[OscBundle, OscMessage, SequenceABC, str]) -> None: + try: + self.osc_server.socket.sendto( + self._send(message), + (self.ip_address, self.port), + ) + except OSError: + # print(message) + raise + + def unregister(self, callback: OscCallback) -> None: + """ + Unregister a callback. + """ + # Command queue prevents lock contention. + self.command_queue.put(("remove", callback)) diff --git a/supriya/scsynth.py b/supriya/scsynth.py index e93b1f873..de454dec6 100644 --- a/supriya/scsynth.py +++ b/supriya/scsynth.py @@ -1,20 +1,23 @@ import asyncio import atexit +import concurrent.futures import enum import logging import os import platform -import signal +import shlex import subprocess -import time +import threading from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import IO, Callable, Dict, List, Optional, Tuple, cast +import psutil import uqbar.io import uqbar.objects from .exceptions import ServerCannotBoot +from .typing import FutureLike logger = logging.getLogger(__name__) @@ -238,16 +241,10 @@ def find(scsynth_path=None): def kill(): - with subprocess.Popen( - ["ps", "-Af"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT - ) as process: - output = process.stdout.read() - for line in output.decode().splitlines(): - parts = line.split() - if not any(part in ["supernova", "scsynth"] for part in parts): - continue - pid = int(parts[1]) - os.kill(pid, signal.SIGKILL) + for process in psutil.process_iter(): + if process.name() in ("scsynth", "supernova", "scsynth.exe", "supernova.exe"): + logger.info(f"killing {process!r}") + process.kill() class LineStatus(enum.IntEnum): @@ -256,159 +253,302 @@ class LineStatus(enum.IntEnum): ERROR = 2 -class ProcessProtocol: - def __init__(self): - self.is_running = False - - def boot(self, options: Options): - raise NotImplementedError +class BootStatus(enum.IntEnum): + OFFLINE = 0 + BOOTING = 1 + ONLINE = 2 + QUITTING = 3 - def quit(self): - raise NotImplementedError - def _handle_line(self, line): - if line.startswith("late:"): - logger.warning(f"Received: {line}") - elif "error" in line.lower() or "exception" in line.lower(): - logger.error(f"Received: {line}") +class ProcessProtocol: + def __init__( + self, + *, + name: Optional[str] = None, + on_boot_callback: Optional[Callable] = None, + on_panic_callback: Optional[Callable] = None, + on_quit_callback: Optional[Callable] = None, + ) -> None: + self.buffer_ = "" + self.error_text = "" + self.name = name + self.on_boot_callback = on_boot_callback + self.on_panic_callback = on_panic_callback + self.on_quit_callback = on_quit_callback + self.status = BootStatus.OFFLINE + self.options = Options() + + def _boot(self, options: Options) -> bool: + self.options = options + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + "booting ..." + ) + if self.status != BootStatus.OFFLINE: + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + "... already booted!" + ) + return False + self.status = BootStatus.BOOTING + self.error_text = "" + self.buffer_ = "" + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + "command: {}".format(shlex.join(options)) + ) + return True + + def _handle_data_received( + self, + *, + boot_future: FutureLike[bool], + text: str, + ) -> Tuple[bool, bool]: + resolved = False + errored = False + if "\n" in text: + text, _, self.buffer_ = text.rpartition("\n") + for line in text.splitlines(): + line_status = self._parse_line(line) + if line_status == LineStatus.READY: + boot_future.set_result(True) + self.status = BootStatus.ONLINE + resolved = True + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + "... booted!" + ) + elif line_status == LineStatus.ERROR: + if not boot_future.done(): + boot_future.set_result(False) + self.status = BootStatus.OFFLINE + self.error_text = line + resolved = True + errored = True + logger.info("... failed to boot!") else: - logger.info(f"Received: {line}") + self.buffer_ = text + return resolved, errored + + def _parse_line(self, line: str) -> LineStatus: + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + f"received: {line}" + ) if line.startswith(("SuperCollider 3 server ready", "Supernova ready")): return LineStatus.READY elif line.startswith(("Exception", "ERROR", "*** ERROR")): return LineStatus.ERROR return LineStatus.CONTINUE + def _quit(self) -> bool: + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + "quitting ..." + ) + if self.status != BootStatus.ONLINE: + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + "... already quit!" + ) + return False + self.status = BootStatus.QUITTING + return True + class SyncProcessProtocol(ProcessProtocol): - def __init__(self): - super().__init__() + def __init__( + self, + *, + name: Optional[str] = None, + on_boot_callback: Optional[Callable] = None, + on_panic_callback: Optional[Callable] = None, + on_quit_callback: Optional[Callable] = None, + ) -> None: + super().__init__( + name=name, + on_boot_callback=on_boot_callback, + on_panic_callback=on_panic_callback, + on_quit_callback=on_quit_callback, + ) atexit.register(self.quit) + self.boot_future: concurrent.futures.Future[bool] = concurrent.futures.Future() + self.exit_future: concurrent.futures.Future[int] = concurrent.futures.Future() + + def _run_process_thread(self, options: Options) -> None: + self.process = subprocess.Popen( + list(options), + stderr=subprocess.STDOUT, + stdout=subprocess.PIPE, + start_new_session=True, + ) + read_thread = threading.Thread( + args=(), + daemon=True, + target=self._run_read_thread, + ) + read_thread.start() + self.process.wait() + was_quitting = self.status == BootStatus.QUITTING + self.status = BootStatus.OFFLINE + self.exit_future.set_result(self.process.returncode) + if not self.boot_future.done(): + self.boot_future.set_result(False) + if was_quitting and self.on_quit_callback: + self.on_quit_callback() + elif not was_quitting and self.on_panic_callback: + self.on_panic_callback() + + def _run_read_thread(self) -> None: + while self.status == BootStatus.BOOTING: + if not (text := cast(IO[bytes], self.process.stdout).readline().decode()): + continue + _, _ = self._handle_data_received(boot_future=self.boot_future, text=text) + while self.status == BootStatus.ONLINE: + if not (text := cast(IO[bytes], self.process.stdout).readline().decode()): + continue + # we can capture /g_dumpTree output here + # do something + + def _shutdown(self) -> None: + self.process.terminate() + self.thread.join() + self.status = BootStatus.OFFLINE - def boot(self, options: Options): - if self.is_running: + def boot(self, options: Options) -> None: + if not self._boot(options): return - try: - logger.info("Boot: {}".format(*options)) - self.process = subprocess.Popen( - list(options), - stderr=subprocess.STDOUT, - stdout=subprocess.PIPE, - start_new_session=True, - ) - start_time = time.time() - timeout = 10 - while True: - line = self.process.stdout.readline().decode().rstrip() # type: ignore - if not line: - continue - line_status = self._handle_line(line) - if line_status == LineStatus.READY: - break - elif line_status == LineStatus.ERROR: - raise ServerCannotBoot(line) - elif (time.time() - start_time) > timeout: - raise ServerCannotBoot(line) - self.is_running = True - except ServerCannotBoot: - self.process.terminate() - self.process.wait() - raise + self.boot_future = concurrent.futures.Future() + self.exit_future = concurrent.futures.Future() + self.thread = threading.Thread( + args=(options,), + daemon=True, + target=self._run_process_thread, + ) + self.thread.start() + if not (self.boot_future.result()): + self._shutdown() + raise ServerCannotBoot(self.error_text) + if self.on_boot_callback: + self.on_boot_callback() def quit(self) -> None: - if not self.is_running: + if not self._quit(): return - self.process.terminate() - self.process.wait() - self.is_running = False + self._shutdown() + self.exit_future.result() + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + "... quit!" + ) class AsyncProcessProtocol(asyncio.SubprocessProtocol, ProcessProtocol): ### INITIALIZER ### - def __init__(self): - ProcessProtocol.__init__(self) + def __init__( + self, + *, + name: Optional[str] = None, + on_boot_callback: Optional[Callable] = None, + on_panic_callback: Optional[Callable] = None, + on_quit_callback: Optional[Callable] = None, + ) -> None: + ProcessProtocol.__init__( + self, + name=name, + on_boot_callback=on_boot_callback, + on_panic_callback=on_panic_callback, + on_quit_callback=on_quit_callback, + ) asyncio.SubprocessProtocol.__init__(self) - self.boot_future = asyncio.Future() - self.exit_future = asyncio.Future() - self.error_text = "" + self.boot_future: asyncio.Future[bool] = asyncio.Future() + self.exit_future: asyncio.Future[bool] = asyncio.Future() ### PUBLIC METHODS ### - async def boot(self, options: Options): - logger.info("Booting ...") - if self.is_running: - logger.info("... already booted!") + async def boot(self, options: Options) -> None: + if not self._boot(options): return - self.is_running = False loop = asyncio.get_running_loop() self.boot_future = loop.create_future() self.exit_future = loop.create_future() - self.error_text = "" - self.buffer_ = "" _, _ = await loop.subprocess_exec( lambda: self, *options, stdin=None, stderr=None ) if not (await self.boot_future): + await self.exit_future raise ServerCannotBoot(self.error_text) - - def connection_made(self, transport): - logger.info("Connection made!") - self.is_running = True + if self.on_boot_callback: + if asyncio.iscoroutine(result := self.on_boot_callback()): + loop.create_task(result) + + def connection_made(self, transport) -> None: + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + "connection made!" + ) self.transport = transport - def pipe_connection_lost(self, fd, exc): - logger.info("Pipe connection lost!") + def pipe_connection_lost(self, fd, exc) -> None: + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + "pipe connection lost!" + ) - def pipe_data_received(self, fd, data): + def pipe_data_received(self, fd, data) -> None: # *nix and OSX return full lines, # but Windows will return partial lines # which obligates us to reconstruct them. text = self.buffer_ + data.decode().replace("\r\n", "\n") - if "\n" in text: - text, _, self.buffer_ = text.rpartition("\n") - for line in text.splitlines(): - line_status = self._handle_line(line) - if line_status == LineStatus.READY: - self.boot_future.set_result(True) - logger.info("... booted!") - elif line_status == LineStatus.ERROR: - if not self.boot_future.done(): - self.boot_future.set_result(False) - self.error_text = line - logger.info("... failed to boot!") - else: - self.buffer_ = text + _, _ = self._handle_data_received(boot_future=self.boot_future, text=text) - def process_exited(self): - logger.info(f"Process exited with {self.transport.get_returncode()}.") - self.is_running = False + def process_exited(self) -> None: + return_code = self.transport.get_returncode() + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + f"process exited with {return_code}." + ) + was_quitting = self.status == BootStatus.QUITTING try: - self.exit_future.set_result(None) + self.exit_future.set_result(return_code) + self.status = BootStatus.OFFLINE if not self.boot_future.done(): self.boot_future.set_result(False) except asyncio.exceptions.InvalidStateError: pass - - async def quit(self): - logger.info("Quitting ...") - if not self.is_running: - logger.info("... already quit!") + if was_quitting and self.on_quit_callback: + if asyncio.iscoroutine(result := self.on_quit_callback()): + asyncio.get_running_loop().create_task(result) + elif not was_quitting and self.on_panic_callback: + if asyncio.iscoroutine(result := self.on_panic_callback()): + asyncio.get_running_loop().create_task(result) + + async def quit(self) -> None: + if not self._quit(): return - self.is_running = False self.transport.close() await self.exit_future - logger.info("... quit!") + logger.info( + f"[{self.options.ip_address}:{self.options.port}/{self.name or hex(id(self))}] " + "... quit!" + ) -class AsyncNonrealtimeProcessProtocol(asyncio.SubprocessProtocol): - def __init__(self, exit_future: asyncio.Future) -> None: - self.buffer_ = "" - self.exit_future = exit_future +class AsyncNonrealtimeProcessProtocol(asyncio.SubprocessProtocol, ProcessProtocol): + def __init__(self) -> None: + ProcessProtocol.__init__(self) + asyncio.SubprocessProtocol.__init__(self) + self.boot_future: asyncio.Future[bool] = asyncio.Future() + self.exit_future: asyncio.Future[bool] = asyncio.Future() async def run(self, command: List[str], render_directory_path: Path) -> None: - logger.info(f"Running: {' '.join(command)}") - _, _ = await asyncio.get_running_loop().subprocess_exec( + logger.info(f"running: {shlex.join(command)}") + loop = asyncio.get_running_loop() + self.boot_future = loop.create_future() + self.exit_future = loop.create_future() + _, _ = await loop.subprocess_exec( lambda: self, *command, stdin=None, @@ -416,27 +556,28 @@ async def run(self, command: List[str], render_directory_path: Path) -> None: start_new_session=True, cwd=render_directory_path, ) + await self.exit_future - def handle_line(self, line: str) -> None: - logger.debug(f"Received: {line}") - - def connection_made(self, transport): - logger.debug("Connecting") + def connection_made(self, transport) -> None: + logger.info("connection made!") self.transport = transport - def pipe_data_received(self, fd, data): - logger.debug(f"Data: {data}") + def pipe_connection_lost(self, fd, exc) -> None: + logger.info("pipe connection lost!") + + def pipe_data_received(self, fd, data) -> None: # *nix and OSX return full lines, # but Windows will return partial lines # which obligates us to reconstruct them. text = self.buffer_ + data.decode().replace("\r\n", "\n") - if "\n" in text: - text, _, self.buffer_ = text.rpartition("\n") - for line in text.splitlines(): - self.handle_line(line) - else: - self.buffer_ = text + _, _ = self._handle_data_received(boot_future=self.boot_future, text=text) - def process_exited(self): - logger.debug(f"Exiting with {self.transport.get_returncode()}") - self.exit_future.set_result(self.transport.get_returncode()) + def process_exited(self) -> None: + return_code = self.transport.get_returncode() + logger.info(f"process exited with {return_code}.") + self.exit_future.set_result(return_code) + try: + if not self.boot_future.done(): + self.boot_future.set_result(False) + except asyncio.exceptions.InvalidStateError: + pass diff --git a/supriya/typing.py b/supriya/typing.py index 5d3917587..c36c22045 100644 --- a/supriya/typing.py +++ b/supriya/typing.py @@ -1,3 +1,5 @@ +import asyncio +import concurrent.futures from os import PathLike from pathlib import Path from typing import ( @@ -73,10 +75,13 @@ def __render_memo__(self) -> SupportsRender: E = TypeVar("E") + _EnumLike = Optional[Union[E, SupportsInt, str, None]] + AddActionLike: TypeAlias = _EnumLike[AddAction] DoneActionLike: TypeAlias = _EnumLike[DoneAction] CalculationRateLike: TypeAlias = _EnumLike[CalculationRate] +FutureLike: TypeAlias = Union[concurrent.futures.Future[E], asyncio.Future[E]] ParameterRateLike: TypeAlias = _EnumLike[ParameterRate] RateLike: TypeAlias = _EnumLike[CalculationRate] EnvelopeShapeLike: TypeAlias = _EnumLike[EnvelopeShape] diff --git a/supriya/ugens/triggers.py b/supriya/ugens/triggers.py index 962726e25..c6d41035a 100644 --- a/supriya/ugens/triggers.py +++ b/supriya/ugens/triggers.py @@ -258,7 +258,7 @@ class Poll(UGen): :: - >>> server = supriya.Server().boot() + >>> server = supriya.Server().boot(port=supriya.osc.find_free_port()) >>> _ = server.add_synthdefs( ... synthdef, ... on_completion=lambda context: context.add_synth(synthdef), diff --git a/tests/book/conftest.py b/tests/book/conftest.py index 4a02fefe8..4d7d525c6 100644 --- a/tests/book/conftest.py +++ b/tests/book/conftest.py @@ -16,10 +16,9 @@ def remove_sphinx_projects(sphinx_test_tempdir) -> None: if Path(d, "_build").exists(): # This directory is a Sphinx project, remove it shutil.rmtree(str(d)) - yield @pytest.fixture() def rootdir(remove_sphinx_projects: None) -> Path: roots = Path(__file__).parent.absolute() / "roots" - yield roots + return roots diff --git a/tests/contexts/test_Server_lifecycle.py b/tests/contexts/test_Server_lifecycle.py index 3b31ca1b6..0d792e3d9 100644 --- a/tests/contexts/test_Server_lifecycle.py +++ b/tests/contexts/test_Server_lifecycle.py @@ -1,7 +1,7 @@ import asyncio import logging +import platform import random -import sys import pytest @@ -10,6 +10,7 @@ AsyncServer, BootStatus, Server, + ServerLifecycleEvent, ) from supriya.exceptions import ( OwnedServerShutdown, @@ -19,21 +20,36 @@ UnownedServerShutdown, ) from supriya.osc import find_free_port +from supriya.scsynth import kill supernova = pytest.param( "supernova", marks=pytest.mark.skipif( - sys.platform.startswith("win"), reason="Supernova won't boot on Windows" + platform.system() == "Windows", reason="Supernova won't boot on Windows" ), ) +def setup_context(context_class, name=None): + def on_event(event: ServerLifecycleEvent) -> None: + events.append(event) + + events = [] + context = context_class(name=name) + for event in ServerLifecycleEvent: + context.on(event, on_event) + return context, events + + async def get(x): if asyncio.iscoroutine(x): return await x return x +logger = logging.getLogger(__name__) + + @pytest.fixture(autouse=True) def use_caplog(caplog): caplog.set_level(logging.INFO) @@ -48,117 +64,229 @@ def healthcheck_attempts(monkeypatch): @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_only(executable, context_class): - context = context_class() + context, events = setup_context(context_class) assert context.boot_status == BootStatus.OFFLINE assert not context.is_owner - result = context.boot(executable=executable) - if asyncio.iscoroutine(result): - await result + # + await get(context.boot(executable=executable)) assert context.boot_status == BootStatus.ONLINE assert context.is_owner + assert events == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert context.boot_future.done() + assert not context.exit_future.done() @pytest.mark.asyncio @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_and_quit(executable, context_class): - context = context_class() + context, events = setup_context(context_class) assert context.boot_status == BootStatus.OFFLINE assert not context.is_owner - result = context.boot(executable=executable) - if asyncio.iscoroutine(result): - await result + # + await get(context.boot(executable=executable)) assert context.boot_status == BootStatus.ONLINE assert context.is_owner - result = context.quit() - if asyncio.iscoroutine(result): - await result + assert events == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + # + await get(context.quit()) assert context.boot_status == BootStatus.OFFLINE assert not context.is_owner + assert events == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ServerLifecycleEvent.QUITTING, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.OSC_DISCONNECTED, + ServerLifecycleEvent.DISCONNECTED, + ServerLifecycleEvent.PROCESS_QUIT, + ServerLifecycleEvent.QUIT, + ] + assert context.boot_future.done() + assert context.exit_future.done() @pytest.mark.asyncio @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_and_boot(executable, context_class): - context = context_class() + context, events = setup_context(context_class) assert context.boot_status == BootStatus.OFFLINE assert not context.is_owner - result = context.boot(executable=executable) - if asyncio.iscoroutine(result): - await result + # + await get(context.boot(executable=executable)) assert context.boot_status == BootStatus.ONLINE assert context.is_owner + assert events == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + # with pytest.raises(ServerOnline): - result = context.boot(executable=executable) - if asyncio.iscoroutine(result): - await result + await get(context.boot(executable=executable)) assert context.boot_status == BootStatus.ONLINE assert context.is_owner + assert events == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] @pytest.mark.asyncio @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_and_quit_and_quit(executable, context_class): - context = context_class() + context, events = setup_context(context_class) assert context.boot_status == BootStatus.OFFLINE assert not context.is_owner - result = context.boot(executable=executable) - if asyncio.iscoroutine(result): - await result + # + await get(context.boot(executable=executable)) assert context.boot_status == BootStatus.ONLINE assert context.is_owner - result = context.quit() - if asyncio.iscoroutine(result): - await result + assert events == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + # + await get(context.quit()) assert context.boot_status == BootStatus.OFFLINE assert not context.is_owner - result = context.quit() - if asyncio.iscoroutine(result): - await result + assert events == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ServerLifecycleEvent.QUITTING, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.OSC_DISCONNECTED, + ServerLifecycleEvent.DISCONNECTED, + ServerLifecycleEvent.PROCESS_QUIT, + ServerLifecycleEvent.QUIT, + ] + # + await get(context.quit()) assert context.boot_status == BootStatus.OFFLINE assert not context.is_owner + assert events == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ServerLifecycleEvent.QUITTING, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.OSC_DISCONNECTED, + ServerLifecycleEvent.DISCONNECTED, + ServerLifecycleEvent.PROCESS_QUIT, + ServerLifecycleEvent.QUIT, + ] @pytest.mark.asyncio @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_and_connect(executable, context_class): - context = context_class() + context, events = setup_context(context_class) assert context.boot_status == BootStatus.OFFLINE assert not context.is_owner - result = context.boot(executable=executable) - if asyncio.iscoroutine(result): - await result + # + await get(context.boot(executable=executable)) assert context.boot_status == BootStatus.ONLINE assert context.is_owner + assert events == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + # with pytest.raises(ServerOnline): - result = context.connect() - if asyncio.iscoroutine(result): - await result + await get(context.connect()) assert context.boot_status == BootStatus.ONLINE assert context.is_owner + assert events == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] @pytest.mark.asyncio @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_a_and_boot_b_cannot_boot(executable, context_class): - context_a, context_b = context_class(), context_class() + context_a, events_a = setup_context(context_class, name="one") + context_b, events_b = setup_context(context_class, name="two") assert context_a.boot_status == BootStatus.OFFLINE and not context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner - result = context_a.boot(maximum_logins=4, executable=executable) - if asyncio.iscoroutine(result): - await result + # + await get(context_a.boot(maximum_logins=4, executable=executable)) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [] + # with pytest.raises(ServerCannotBoot): - result = context_b.boot(maximum_logins=4, executable=executable) - if asyncio.iscoroutine(result): - await result + await get(context_b.boot(maximum_logins=4, executable=executable)) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_PANICKED, + ] # scsynth only @@ -166,140 +294,352 @@ async def test_boot_a_and_boot_b_cannot_boot(executable, context_class): @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_a_and_connect_b_too_many_clients(executable, context_class): - context_a, context_b = context_class(), context_class() + context_a, events_a = setup_context(context_class, name="one") + context_b, events_b = setup_context(context_class, name="two") assert context_a.boot_status == BootStatus.OFFLINE and not context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner - result = context_a.boot(maximum_logins=1) - if asyncio.iscoroutine(result): - await result + # + await get(context_a.boot(maximum_logins=1)) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [] + # with pytest.raises(TooManyClients): - result = context_b.connect() - if asyncio.iscoroutine(result): - await result + await get(context_b.connect()) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.DISCONNECTED, + ] @pytest.mark.asyncio @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_a_and_connect_b_and_quit_a(executable, context_class): - context_a, context_b = context_class(), context_class() + logger.warning("START") + context_a, events_a = setup_context(context_class, name="one") + context_b, events_b = setup_context(context_class, name="two") assert context_a.boot_status == BootStatus.OFFLINE and not context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner - result = context_a.boot(maximum_logins=2, executable=executable) - if asyncio.iscoroutine(result): - await result - result = context_b.connect() - if asyncio.iscoroutine(result): - await result + # + logger.warning("BOOT A") + await get(context_a.boot(maximum_logins=2, executable=executable)) + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [] + # + logger.warning("CONNECT B") + await get(context_b.connect()) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.ONLINE and not context_b.is_owner - result = context_a.quit() - if asyncio.iscoroutine(result): - await result + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ] + # + logger.warning("PROCESS_QUIT A") + await get(context_a.quit()) assert context_a.boot_status == BootStatus.OFFLINE and not context_a.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ServerLifecycleEvent.QUITTING, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.OSC_DISCONNECTED, + ServerLifecycleEvent.DISCONNECTED, + ServerLifecycleEvent.PROCESS_QUIT, + ServerLifecycleEvent.QUIT, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ] + # + logger.warning("AWAIT B") for _ in range(100): await asyncio.sleep(0.1) if context_b.boot_status == BootStatus.OFFLINE: break + logger.warning("DONE AWAITING B") assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ServerLifecycleEvent.QUITTING, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.OSC_DISCONNECTED, + ServerLifecycleEvent.DISCONNECTED, + ServerLifecycleEvent.PROCESS_QUIT, + ServerLifecycleEvent.QUIT, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.OSC_PANICKED, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.DISCONNECTED, + ] + logger.warning("END") @pytest.mark.asyncio @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_a_and_connect_b_and_disconnect_b(executable, context_class): - context_a, context_b = context_class(), context_class() + context_a, events_a = setup_context(context_class, name="one") + context_b, events_b = setup_context(context_class, name="two") assert context_a.boot_status == BootStatus.OFFLINE and not context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner - result = context_a.boot(maximum_logins=2, executable=executable) - if asyncio.iscoroutine(result): - await result - result = context_b.connect() - if asyncio.iscoroutine(result): - await result + # + await get(context_a.boot(maximum_logins=2, executable=executable)) + await get(context_b.connect()) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.ONLINE and not context_b.is_owner - result = context_b.disconnect() - if asyncio.iscoroutine(result): - await result + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ] + # + await get(context_b.disconnect()) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.OSC_DISCONNECTED, + ServerLifecycleEvent.DISCONNECTED, + ] @pytest.mark.asyncio @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_a_and_connect_b_and_disconnect_a(executable, context_class): - context_a, context_b = context_class(), context_class() + context_a, events_a = setup_context(context_class, name="one") + context_b, events_b = setup_context(context_class, name="two") assert context_a.boot_status == BootStatus.OFFLINE and not context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner - result = context_a.boot(maximum_logins=2, executable=executable) - if asyncio.iscoroutine(result): - await result - result = context_b.connect() - if asyncio.iscoroutine(result): - await result + # + await get(context_a.boot(maximum_logins=2, executable=executable)) + await get(context_b.connect()) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.ONLINE and not context_b.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ] + # with pytest.raises(OwnedServerShutdown): - result = context_a.disconnect() - if asyncio.iscoroutine(result): - await result + await get(context_a.disconnect()) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.ONLINE and not context_b.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ] @pytest.mark.asyncio @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_a_and_connect_b_and_quit_b(executable, context_class): - context_a, context_b = context_class(), context_class() + context_a, events_a = setup_context(context_class, name="one") + context_b, events_b = setup_context(context_class, name="two") assert context_a.boot_status == BootStatus.OFFLINE and not context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner - result = context_a.boot(maximum_logins=2, executable=executable) - if asyncio.iscoroutine(result): - await result - result = context_b.connect() - if asyncio.iscoroutine(result): - await result + # + await get(context_a.boot(maximum_logins=2, executable=executable)) + await get(context_b.connect()) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.ONLINE and not context_b.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ] + # with pytest.raises(UnownedServerShutdown): - result = context_b.quit() - if asyncio.iscoroutine(result): - await result + await get(context_b.quit()) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.ONLINE and not context_b.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ] @pytest.mark.asyncio @pytest.mark.parametrize("executable", ["scsynth", supernova]) @pytest.mark.parametrize("context_class", [AsyncServer, Server]) async def test_boot_a_and_connect_b_and_force_quit_b(executable, context_class): - context_a, context_b = context_class(), context_class() + logger.warning("START") + context_a, events_a = setup_context(context_class, name="one") + context_b, events_b = setup_context(context_class, name="two") assert context_a.boot_status == BootStatus.OFFLINE and not context_a.is_owner assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner - result = context_a.boot(maximum_logins=2, executable=executable) - if asyncio.iscoroutine(result): - await result - result = context_b.connect() - if asyncio.iscoroutine(result): - await result + # + logger.warning("BOOT A") + await get(context_a.boot(maximum_logins=2, executable=executable)) + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [] + # + logger.warning("CONNECT B") + await get(context_b.connect()) assert context_a.boot_status == BootStatus.ONLINE and context_a.is_owner assert context_b.boot_status == BootStatus.ONLINE and not context_b.is_owner - result = context_b.quit(force=True) - if asyncio.iscoroutine(result): - await result + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ] + # + logger.warning("FORCE PROCESS_QUIT B") + await get(context_b.quit(force=True)) assert context_b.boot_status == BootStatus.OFFLINE and not context_b.is_owner + logger.warning("AWAIT A") for _ in range(100): await asyncio.sleep(0.1) if context_a.boot_status == BootStatus.OFFLINE: break assert context_a.boot_status == BootStatus.OFFLINE and not context_a.is_owner + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ServerLifecycleEvent.PROCESS_PANICKED, + ServerLifecycleEvent.OSC_PANICKED, + ServerLifecycleEvent.QUITTING, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.DISCONNECTED, + ServerLifecycleEvent.QUIT, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.QUITTING, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.OSC_DISCONNECTED, + ServerLifecycleEvent.DISCONNECTED, + ServerLifecycleEvent.QUIT, + ] + logger.warning("END") @pytest.mark.asyncio @@ -309,7 +649,7 @@ async def test_boot_reboot_sticky_options(executable, context_class): """ Options persist across booting and quitting. """ - context = context_class() + context, _ = setup_context(context_class) maximum_node_count = random.randint(1024, 2048) await get( context.boot(maximum_node_count=maximum_node_count, port=find_free_port()) @@ -321,3 +661,75 @@ async def test_boot_reboot_sticky_options(executable, context_class): assert context.options.maximum_node_count == maximum_node_count await get(context.quit()) assert context.options.maximum_node_count == maximum_node_count + + +@pytest.mark.asyncio +@pytest.mark.parametrize("executable", ["scsynth", supernova]) +@pytest.mark.parametrize("context_class", [AsyncServer, Server]) +async def test_boot_a_and_connect_b_and_kill(executable, context_class) -> None: + logger.warning("START") + context_a, events_a = setup_context(context_class, name="one") + context_b, events_b = setup_context(context_class, name="two") + logger.warning("BOOT A") + await get(context_a.boot(executable=executable, maximum_logins=2)) + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [] + logger.warning("CONNECT B") + await get(context_b.connect()) + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ] + logger.warning("KILL") + kill() + logger.warning("AWAIT A AND B") + for _ in range(100): + await asyncio.sleep(0.1) + if ( + context_a.boot_status == BootStatus.OFFLINE + and context_b.boot_status == BootStatus.OFFLINE + ): + break + assert events_a == [ + ServerLifecycleEvent.BOOTING, + ServerLifecycleEvent.PROCESS_BOOTED, + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.BOOTED, + ServerLifecycleEvent.PROCESS_PANICKED, + ServerLifecycleEvent.OSC_PANICKED, + ServerLifecycleEvent.QUITTING, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.DISCONNECTED, + ServerLifecycleEvent.QUIT, + ] + assert events_b == [ + ServerLifecycleEvent.CONNECTING, + ServerLifecycleEvent.OSC_CONNECTED, + ServerLifecycleEvent.CONNECTED, + ServerLifecycleEvent.OSC_PANICKED, + ServerLifecycleEvent.DISCONNECTING, + ServerLifecycleEvent.DISCONNECTED, + ] + assert context_a.boot_future.done() + assert context_a.exit_future.done() + assert context_b.boot_future.done() + assert context_b.exit_future.done() + logger.warning("END") diff --git a/tests/patterns/test_Pattern.py b/tests/patterns/test_Pattern.py index 9a7a645be..6202ed1da 100644 --- a/tests/patterns/test_Pattern.py +++ b/tests/patterns/test_Pattern.py @@ -216,22 +216,22 @@ def test_binary_ops(op, expr_one, expr_two, expected): ( operator.abs, SequencePattern([1, 2, 3]), - UnaryOpPattern(operator.abs, SequencePattern([1, 2, 3])), + UnaryOpPattern[int](operator.abs, SequencePattern([1, 2, 3])), ), ( operator.inv, SequencePattern([1, 2, 3]), - UnaryOpPattern(operator.invert, SequencePattern([1, 2, 3])), + UnaryOpPattern[int](operator.invert, SequencePattern([1, 2, 3])), ), ( operator.neg, SequencePattern([1, 2, 3]), - UnaryOpPattern(operator.neg, SequencePattern([1, 2, 3])), + UnaryOpPattern[int](operator.neg, SequencePattern([1, 2, 3])), ), ( operator.pos, SequencePattern([1, 2, 3]), - UnaryOpPattern(operator.pos, SequencePattern([1, 2, 3])), + UnaryOpPattern[int](operator.pos, SequencePattern([1, 2, 3])), ), ], ) diff --git a/tests/test_osc.py b/tests/test_osc.py index 15d02662a..4de88c09c 100644 --- a/tests/test_osc.py +++ b/tests/test_osc.py @@ -1,12 +1,12 @@ import asyncio +import concurrent.futures import logging -import time +from typing import List import pytest from uqbar.strings import normalize from supriya.osc import ( - NTP_DELTA, AsyncOscProtocol, HealthCheck, OscBundle, @@ -14,10 +14,24 @@ ThreadedOscProtocol, find_free_port, ) +from supriya.osc.messages import NTP_DELTA +from supriya.osc.protocols import BootStatus from supriya.scsynth import AsyncProcessProtocol, Options, SyncProcessProtocol +logger = logging.getLogger(__name__) -def test_OscMessage(): + +async def get(x): + if asyncio.iscoroutine(x): + return await x + elif asyncio.isfuture(x): + return await x + elif isinstance(x, concurrent.futures.Future): + return x.result() + return x + + +def test_OscMessage() -> None: osc_message = OscMessage( "/foo", 1, @@ -67,7 +81,7 @@ def test_OscMessage(): ) -def test_new_ntp_era(): +def test_new_ntp_era() -> None: """ Check for NTP timestamp overflow. """ @@ -76,74 +90,59 @@ def test_new_ntp_era(): assert datagram.hex() == "0000000100000000" -@pytest.fixture(autouse=True) -def log_everything(caplog): - caplog.set_level(logging.DEBUG, logger="supriya.osc") - caplog.set_level(logging.DEBUG, logger="supriya.server") - - +@pytest.mark.parametrize( + "osc_protocol_class, process_protocol_class", + [ + (AsyncOscProtocol, AsyncProcessProtocol), + (ThreadedOscProtocol, SyncProcessProtocol), + ], +) @pytest.mark.asyncio -async def test_AsyncOscProtocol(): - def on_healthcheck_failed(): +async def test_OscProtocol(osc_protocol_class, process_protocol_class) -> None: + def on_healthcheck_failed() -> None: healthcheck_failed.append(True) + logger.info("START") + healthcheck_failed: List[bool] = [] + port = find_free_port() + + logger.info("INIT PROCESS") + process_protocol = process_protocol_class() + + logger.info("INIT PROTOCOL") + osc_protocol = osc_protocol_class(on_panic_callback=on_healthcheck_failed) + assert osc_protocol.status == BootStatus.OFFLINE + try: - healthcheck_failed = [] - port = find_free_port() - options = Options(port=port) - healthcheck = HealthCheck( - request_pattern=["/status"], - response_pattern=["/status.reply"], - callback=on_healthcheck_failed, - max_attempts=3, + logger.info("BOOT PROCESS") + await get(process_protocol.boot(Options(port=port))) + await get(process_protocol.boot_future) + assert process_protocol.status == BootStatus.ONLINE + + logger.info("CONNECT PROTOCOL") + await get( + osc_protocol.connect( + "127.0.0.1", + port, + healthcheck=HealthCheck( + request_pattern=["/status"], + response_pattern=["/status.reply"], + max_attempts=3, + ), + ) ) - process_protocol = AsyncProcessProtocol() - await process_protocol.boot(options) - assert await process_protocol.boot_future - osc_protocol = AsyncOscProtocol() - await osc_protocol.connect("127.0.0.1", port, healthcheck=healthcheck) - assert osc_protocol.is_running - assert not healthcheck_failed - await asyncio.sleep(1) - await process_protocol.quit() - for _ in range(20): - await asyncio.sleep(1) - if not osc_protocol.is_running: - break - assert healthcheck_failed - assert not osc_protocol.is_running - finally: - await process_protocol.quit() + assert osc_protocol.status == BootStatus.BOOTING + logger.info("AWAIT CONNECTION") + await get(osc_protocol.boot_future) + assert osc_protocol.status == BootStatus.ONLINE -def test_ThreadedOscProtocol(): - def on_healthcheck_failed(): - healthcheck_failed.append(True) + logger.info("QUIT PROCESS") + await get(process_protocol.quit()) - healthcheck_failed = [] - options = Options() - port = find_free_port() - healthcheck = HealthCheck( - request_pattern=["/status"], - response_pattern=["/status.reply"], - callback=on_healthcheck_failed, - max_attempts=3, - ) - process_protocol = SyncProcessProtocol() - process_protocol.boot(options) - assert process_protocol.is_running - osc_protocol = ThreadedOscProtocol() - osc_protocol.connect("127.0.0.1", port, healthcheck=healthcheck) - assert osc_protocol.is_running - assert not healthcheck_failed - time.sleep(1) - process_protocol.quit() - assert not process_protocol.is_running - assert osc_protocol.is_running - assert not healthcheck_failed - for _ in range(20): - time.sleep(1) - if not osc_protocol.is_running: - break - assert healthcheck_failed - assert not osc_protocol.is_running + logger.info("AWAIT DISCONNECTION") + await get(osc_protocol.exit_future) + assert osc_protocol.status == BootStatus.OFFLINE + + finally: + await get(process_protocol.quit())