Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove subscription and use internal interval #31

Merged
merged 7 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions config/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# the configuration for the chosen engine as described below.
provider_engine = 'pyth_replicator'

product_update_interval_secs = 10
price_update_interval_secs = 1.0
product_update_interval_secs = 60
health_check_port = 8000

# The health check will return a failure status if no price data has been published within the specified time frame.
Expand All @@ -22,7 +23,7 @@ endpoint = 'ws://127.0.0.1:8910'
# coin_gecko_id = 'bitcoin'

[publisher.pyth_replicator]
http_endpoint = 'https://pythnet.rpcpool.com'
ws_endpoint = 'wss://pythnet.rpcpool.com'
http_endpoint = 'https://api2.pythnet.pyth.network'
ws_endpoint = 'wss://api2.pythnet.pyth.network'
first_mapping = 'AHtgzX45WTKfkPG53L6WYhGEXwQkN1BVknET3sVsLL8J'
program_key = 'FsJ3A3u2vn5cTVofAjvy6y5kwABJAqYWpe4975bi2epH'
2 changes: 1 addition & 1 deletion example_publisher/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_DEFAULT_CONFIG_PATH = os.path.join("config", "config.toml")


log_level = logging._nameToLevel[os.environ.get("LOG_LEVEL", "DEBUG").upper()]
log_level = logging._nameToLevel[os.environ.get("LOG_LEVEL", "INFO").upper()]
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(log_level))

log = structlog.get_logger()
Expand Down
1 change: 1 addition & 0 deletions example_publisher/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Config:
pythd: Pythd
health_check_port: int
health_check_threshold_secs: int
price_update_interval_secs: float = ts.option(default=1.0)
product_update_interval_secs: int = ts.option(default=60)
coin_gecko: Optional[CoinGeckoConfig] = ts.option(default=None)
pyth_replicator: Optional[PythReplicatorConfig] = ts.option(default=None)
4 changes: 2 additions & 2 deletions example_publisher/providers/pyth_replicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def _update_loop(self) -> None:
update.timestamp,
)

log.info(
log.debug(
"Received a price update", symbol=symbol, price=self._prices[symbol]
)

Expand All @@ -118,7 +118,7 @@ async def _update_accounts_loop(self) -> None:

await asyncio.sleep(self._config.account_update_interval_secs)

def upd_products(self, *args) -> None:
def upd_products(self, product_symbols: List[Symbol]) -> None:
# This provider stores all the possible feeds and
# does not care about the desired products as knowing
# them does not improve the performance of the replicator
Expand Down
103 changes: 47 additions & 56 deletions example_publisher/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from example_publisher.providers.coin_gecko import CoinGecko
from example_publisher.config import Config
from example_publisher.providers.pyth_replicator import PythReplicator
from example_publisher.pythd import Pythd, SubscriptionId
from example_publisher.pythd import PriceUpdate, Pythd, SubscriptionId


log = get_logger()
Expand Down Expand Up @@ -50,7 +50,6 @@ def __init__(self, config: Config) -> None:

self.pythd: Pythd = Pythd(
address=config.pythd.endpoint,
on_notify_price_sched=self.on_notify_price_sched,
)
self.subscriptions: Dict[SubscriptionId, Product] = {}
self.products: List[Product] = []
Expand All @@ -66,18 +65,17 @@ def is_healthy(self) -> bool:
async def start(self):
await self.pythd.connect()

self._product_update_task = asyncio.create_task(
self._start_product_update_loop()
)

async def _start_product_update_loop(self):
await self._upd_products()

self._product_update_task = asyncio.create_task(self._product_update_loop())
self._price_update_task = asyncio.create_task(self._price_update_loop())

self.provider.start()

async def _product_update_loop(self):
while True:
await self._upd_products()
await self._subscribe_notify_price_sched()
await asyncio.sleep(self.config.product_update_interval_secs)
await self._upd_products()

async def _upd_products(self):
log.debug("fetching product accounts from Pythd")
Expand Down Expand Up @@ -114,58 +112,51 @@ async def _upd_products(self):

self.provider.upd_products([product.symbol for product in self.products])

async def _subscribe_notify_price_sched(self):
# Subscribe to Pythd's notify_price_sched for each product that
# is not subscribed yet. Unfortunately there is no way to unsubscribe
# to the prices that are no longer available.
log.debug("subscribing to notify_price_sched")

subscriptions = {}
for product in self.products:
if not product.subscription_id:
subscription_id = await self.pythd.subscribe_price_sched(
product.price_account
async def _price_update_loop(self):
while True:
price_updates = []
for product in self.products:
price = self.provider.latest_price(product.symbol)
if not price:
log.info("latest price not available", symbol=product.symbol)
continue

scaled_price = self.apply_exponent(price.price, product.exponent)
scaled_conf = self.apply_exponent(price.conf, product.exponent)

price_updates.append(
PriceUpdate(
account=product.price_account,
price=scaled_price,
conf=scaled_conf,
status=TRADING,
)
)
log.debug(
"sending price update",
symbol=product.symbol,
price_account=product.price_account,
price=price.price,
conf=price.conf,
scaled_price=scaled_price,
scaled_conf=scaled_conf,
)
product.subscription_id = subscription_id

subscriptions[product.subscription_id] = product

self.subscriptions = subscriptions

async def on_notify_price_sched(self, subscription: int) -> None:

log.debug("received notify_price_sched", subscription=subscription)
if subscription not in self.subscriptions:
return
self.last_successful_update = (
price.timestamp
if self.last_successful_update is None
else max(self.last_successful_update, price.timestamp)
)

# Look up the current price and confidence interval of the product
product = self.subscriptions[subscription]
price = self.provider.latest_price(product.symbol)
if not price:
log.info("latest price not available", symbol=product.symbol)
return
log.info(
"sending batch update_price",
num_price_updates=len(price_updates),
total_products=len(self.products),
)

# Scale the price and confidence interval using the Pyth exponent
scaled_price = self.apply_exponent(price.price, product.exponent)
scaled_conf = self.apply_exponent(price.conf, product.exponent)
await self.pythd.update_price_batch(price_updates)

# Send the price update
log.info(
"sending update_price",
product_account=product.product_account,
price_account=product.price_account,
price=scaled_price,
conf=scaled_conf,
symbol=product.symbol,
)
await self.pythd.update_price(
product.price_account, scaled_price, scaled_conf, TRADING
)
self.last_successful_update = (
price.timestamp
if self.last_successful_update is None
else max(self.last_successful_update, price.timestamp)
)
await asyncio.sleep(self.config.price_update_interval_secs)

def apply_exponent(self, x: float, exp: int) -> int:
return int(x * (10 ** (-exp)))
125 changes: 78 additions & 47 deletions example_publisher/pythd.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
from dataclasses import dataclass, field
import sys
import traceback
from dataclasses_json import config, DataClassJsonMixin
from typing import Callable, Coroutine, List
import json
from dataclasses_json import config, DataClassJsonMixin, dataclass_json
from dataclasses_json.undefined import Undefined
from typing import List, Any, Optional
from structlog import get_logger
from jsonrpc_websocket import Server
from websockets.client import connect, WebSocketClientProtocol
from asyncio import Lock

log = get_logger()

Expand All @@ -15,12 +15,22 @@
TRADING = "trading"


@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class Price(DataClassJsonMixin):
account: str
exponent: int = field(metadata=config(field_name="price_exponent"))


@dataclass
class PriceUpdate(DataClassJsonMixin):
account: str
price: int
conf: int
status: str


@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass
class Metadata(DataClassJsonMixin):
symbol: str
Expand All @@ -34,56 +44,77 @@ class Product(DataClassJsonMixin):
prices: List[Price] = field(metadata=config(field_name="price"))


@dataclass
class JSONRPCRequest(DataClassJsonMixin):
id: int
method: str
params: List[Any] | Any
jsonrpc: str = "2.0"


@dataclass
class JSONRPCResponse(DataClassJsonMixin):
id: int
result: Optional[Any] = None
error: Optional[Any] = None
jsonrpc: str = "2.0"


class Pythd:
def __init__(
self,
address: str,
on_notify_price_sched: Callable[[SubscriptionId], Coroutine[None, None, None]],
) -> None:
self.address = address
self.server: Server
self.on_notify_price_sched = on_notify_price_sched
self._tasks = set()
self.client: WebSocketClientProtocol
self.id_counter = 0
self.lock = Lock()

async def connect(self):
self.server = Server(self.address)
self.server.notify_price_sched = self._notify_price_sched
task = await self.server.ws_connect()
task.add_done_callback(Pythd._on_connection_done)
self._tasks.add(task)

@staticmethod
def _on_connection_done(task):
log.error("pythd connection closed")
if not task.cancelled() and task.exception() is not None:
e = task.exception()
traceback.print_exception(None, e, e.__traceback__)
sys.exit(1)

async def subscribe_price_sched(self, account: str) -> int:
subscription = (await self.server.subscribe_price_sched(account=account))[
"subscription"
]
log.debug(
"subscribed to price_sched", account=account, subscription=subscription
self.client = await connect(self.address)

def _create_request(self, method: str, params: List[Any] | Any) -> JSONRPCRequest:
self.id_counter += 1
return JSONRPCRequest(
id=self.id_counter,
method=method,
params=params,
)
return subscription

def _notify_price_sched(self, subscription: int) -> None:
log.debug("notify_price_sched RPC call received", subscription=subscription)
task = asyncio.get_event_loop().create_task(
self.on_notify_price_sched(subscription)
)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
async def send_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
# Using a lock will result in a synchronous execution of the send_request method
# and response retrieval which makes the code easier but is not good for performance.
# It is not recommended to use this behaviour where there are concurrent requests
# being made to the server.
async with self.lock:
await self.client.send(request.to_json())
response = await self.client.recv()
return JSONRPCResponse.from_json(response)

async def send_batch_request(
self, requests: List[JSONRPCRequest]
) -> List[JSONRPCResponse]:
async with self.lock:
await self.client.send(
json.dumps([request.to_dict() for request in requests])
)
response = await self.client.recv()
return JSONRPCResponse.schema().loads(response, many=True)

async def all_products(self) -> List[Product]:
result = await self.server.get_product_list()
return [Product.from_dict(d) for d in result]

async def update_price(
self, account: str, price: int, conf: int, status: str
) -> None:
await self.server.update_price(
account=account, price=price, conf=conf, status=status
)
request = self._create_request("get_product_list", [])
result = await self.send_request(request)
if result.result:
return Product.schema().load(result.result, many=True)
else:
raise ValueError(f"Error fetching products: {result.to_json()}")

async def update_price_batch(self, price_updates: List[PriceUpdate]) -> None:
requests = [
self._create_request("update_price", price_update.to_dict())
for price_update in price_updates
]
results = await self.send_batch_request(requests)
if any(result.error for result in results):
results_json_str = JSONRPCResponse.schema().dumps(results, many=True)
raise ValueError(f"Error updating prices: {results_json_str}")
Loading
Loading