Skip to content

Commit

Permalink
feat: Adding performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
TomMcL committed Sep 22, 2023
1 parent 54f667d commit a87ecb9
Show file tree
Hide file tree
Showing 14 changed files with 278 additions and 102 deletions.
6 changes: 6 additions & 0 deletions examples/nullchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
vega.forward("10s")
vega.wait_for_total_catchup()

import pdb

pdb.set_trace()
vega.create_asset(
MM_WALLET.name,
name="tDAI",
Expand Down Expand Up @@ -213,6 +216,9 @@
wait=True,
)

import pdb

pdb.set_trace()
vega.cancel_order(MM_WALLET.name, market_id, to_cancel)

vega.submit_order(
Expand Down
110 changes: 82 additions & 28 deletions vega_sim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,24 @@ class MissingMarketError(Exception):
S = TypeVar("S")

PartyMarketAccount = namedtuple("PartyMarketAccount", ["general", "margin", "bond"])
AccountData = namedtuple(
"AccountData", ["owner", "balance", "asset", "market_id", "type"]
)
RiskFactor = namedtuple("RiskFactors", ["market_id", "short", "long"])
OrderBook = namedtuple("OrderBook", ["bids", "asks"])
PriceLevel = namedtuple("PriceLevel", ["price", "number_of_orders", "volume"])


@dataclass
class AccountData:
owner: str
balance: float
asset: str
market_id: str
type: vega_protos.vega.AccountType

@property
def account_id(self):
return f"{self.owner}-{self.type}-{self.market_id}-{self.asset}"


@dataclass
class IcebergOrder:
peak_size: float
Expand Down Expand Up @@ -389,9 +399,11 @@ def _order_from_proto(
updated_at=order.updated_at,
version=order.version,
market_id=order.market_id,
iceberg_order=_iceberg_order_from_proto(order.iceberg_order, decimal_spec)
if order.HasField("iceberg_order")
else None,
iceberg_order=(
_iceberg_order_from_proto(order.iceberg_order, decimal_spec)
if order.HasField("iceberg_order")
else None
),
)


Expand Down Expand Up @@ -519,9 +531,9 @@ def positions_by_market(
market_info = data_raw.market_info(
market_id=pos.market_id, data_client=data_client
)
market_to_asset_map[
pos.market_id
] = market_info.tradable_instrument.instrument.future.settlement_asset
market_to_asset_map[pos.market_id] = (
market_info.tradable_instrument.instrument.future.settlement_asset
)

# Update maps if value does not exist for current asset id
if market_to_asset_map[pos.market_id] not in asset_decimals_map:
Expand Down Expand Up @@ -669,6 +681,16 @@ def _liquidity_provider_fee_share_from_proto(
]


def _account_from_proto(account, decimal_spec: DecimalSpec) -> AccountData:
return AccountData(
owner=account.owner,
balance=num_from_padded_int(int(account.balance), decimal_spec.asset_decimals),
asset=account.asset,
type=account.type,
market_id=account.market_id,
)


def list_accounts(
data_client: vac.VegaTradingDataClientV2,
pub_key: Optional[str] = None,
Expand All @@ -695,14 +717,8 @@ def list_accounts(
)

output_accounts.append(
AccountData(
owner=account.owner,
balance=num_from_padded_int(
int(account.balance), asset_decimals_map.setdefault(account.asset)
),
asset=account.asset,
type=account.type,
market_id=account.market_id,
_account_from_proto(
account, DecimalSpec(asset_decimals=asset_decimals_map[account.asset])
)
)
return output_accounts
Expand All @@ -727,22 +743,42 @@ def party_account(
asset_dp if asset_dp is not None else get_asset_decimals(asset_id, data_client)
)

return account_list_to_party_account(
[
account
for account in accounts
if account.market_id is None or account.market_id == market_id
],
asset_dp_conversion=asset_dp,
)


def account_list_to_party_account(
accounts: Union[
List[data_node_protos_v2.trading_data.AccountBalance], List[AccountData]
],
asset_dp_conversion: Optional[int] = None,
):
general, margin, bond = 0, 0, 0 # np.nan, np.nan, np.nan
for account in accounts:
if (
account.market_id
and account.market_id != "!"
and account.market_id != market_id
):
# The 'general' account type has no market ID, so we have to pull
# all markets then filter down here
continue
if account.type == vega_protos.vega.ACCOUNT_TYPE_GENERAL:
general = num_from_padded_int(float(account.balance), asset_dp)
general = (
num_from_padded_int(float(account.balance), asset_dp_conversion)
if asset_dp_conversion is not None
else account.balance
)
if account.type == vega_protos.vega.ACCOUNT_TYPE_MARGIN:
margin = num_from_padded_int(float(account.balance), asset_dp)
margin = (
num_from_padded_int(float(account.balance), asset_dp_conversion)
if asset_dp_conversion is not None
else account.balance
)
if account.type == vega_protos.vega.ACCOUNT_TYPE_BOND:
bond = num_from_padded_int(float(account.balance), asset_dp)
bond = (
num_from_padded_int(float(account.balance), asset_dp_conversion)
if asset_dp_conversion is not None
else account.balance
)

return PartyMarketAccount(general, margin, bond)

Expand Down Expand Up @@ -1506,6 +1542,24 @@ def transfer_subscription_handler(
)


def accounts_subscription_handler(
stream: Iterable[vega_protos.api.v1.core.ObserveEventBusResponse],
mkt_pos_dp: Optional[Dict[str, int]] = None,
mkt_price_dp: Optional[Dict[str, int]] = None,
mkt_to_asset: Optional[Dict[str, str]] = None,
asset_dp: Optional[Dict[str, int]] = None,
) -> Transfer:
return _stream_handler(
stream_item=stream,
extraction_fn=lambda evt: evt.account,
conversion_fn=_account_from_proto,
mkt_pos_dp=mkt_pos_dp,
mkt_price_dp=mkt_price_dp,
mkt_to_asset=mkt_to_asset,
asset_dp=asset_dp,
)


def ledger_entries_subscription_handler(
stream_item: vega_protos.api.v1.core.ObserveEventBusResponse,
asset_dp: Optional[Dict[str, int]] = None,
Expand Down
1 change: 1 addition & 0 deletions vega_sim/environment/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class VegaState:
network_state: Tuple
market_state: Dict[str, Any]
time: int


class Agent(ABC):
Expand Down
10 changes: 7 additions & 3 deletions vega_sim/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,14 @@ def _default_state_extraction(self, vega: VegaService) -> VegaState:
orders=order_status.get(market_id, {}),
)

return VegaState(network_state=(), market_state=market_state)
return VegaState(
network_state=(),
market_state=market_state,
time=vega.get_blockchain_time_from_feed(),
)

def step(self, vega: VegaService) -> None:
vega.wait_for_thread_catchup()
# vega.wait_for_thread_catchup()
state = self.state_func(vega)
for agent in (
sorted(self.agents, key=lambda _: self.random_state.random())
Expand All @@ -428,7 +432,7 @@ def step(self, vega: VegaService) -> None:
" received."
)
# Mint forwards blocks, wait for catchup
vega.wait_for_total_catchup()
# vega.wait_for_total_catchup()


class NetworkEnvironment(MarketEnvironmentWithState):
Expand Down
104 changes: 95 additions & 9 deletions vega_sim/local_data_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
self.orders_lock = threading.RLock()
self.transfers_lock = threading.RLock()
self.asset_lock = threading.RLock()
self.account_lock = threading.RLock()
self.market_lock = threading.RLock()
self.market_data_lock = threading.RLock()
self.trades_lock = threading.RLock()
Expand All @@ -114,6 +115,9 @@ def __init__(
self._market_from_feed = {}
self.market_data_from_feed_store = {}
self._transfer_state_from_feed = {}
self._accounts_from_feed = {}
self._account_keys_for_party = {}
self._account_keys_for_market = {}
self._trades_from_feed: List[data.Trade] = []
self._ledger_entries_from_feed: List[data.LedgerEntry] = []

Expand Down Expand Up @@ -164,6 +168,16 @@ def __init__(
self._asset_decimals,
),
),
(
(events_protos.BUS_EVENT_TYPE_ACCOUNT,),
lambda evt: data.accounts_subscription_handler(
evt,
self._market_pos_decimals,
self._market_price_decimals,
self._market_to_asset,
self._asset_decimals,
),
),
]
self._high_load_stream_registry = [
(
Expand Down Expand Up @@ -282,6 +296,7 @@ def start_live_feeds(
)
self.initialise_assets()
self.initialise_markets()
self.initialise_accounts()
self.initialise_time_update_monitoring()
self.initialise_order_monitoring(
market_ids=market_ids,
Expand Down Expand Up @@ -327,6 +342,19 @@ def initialise_assets(self):
for asset in base_assets:
self._asset_from_feed[asset.id] = asset

def initialise_accounts(self):
base_accounts = data.list_accounts(data_client=self._trading_data_client)

with self.account_lock:
for account in base_accounts:
self._accounts_from_feed[account.account_id] = account
self._account_keys_for_party.setdefault(account.owner, set()).add(
account.account_id
)
self._account_keys_for_market.setdefault(account.market_id, set()).add(
account.account_id
)

def initialise_markets(self):
base_markets = data_raw.all_markets(self._trading_data_client)

Expand Down Expand Up @@ -388,15 +416,15 @@ def initialise_market_data(
]
with self.market_data_lock:
for market_id in market_ids:
self.market_data_from_feed_store[
market_id
] = data.get_latest_market_data(
market_id,
data_client=self._trading_data_client,
market_price_decimals_map=self._market_price_decimals,
market_position_decimals_map=self._market_pos_decimals,
asset_decimals_map=self._asset_decimals,
market_to_asset_map=self._market_to_asset,
self.market_data_from_feed_store[market_id] = (
data.get_latest_market_data(
market_id,
data_client=self._trading_data_client,
market_price_decimals_map=self._market_price_decimals,
market_position_decimals_map=self._market_pos_decimals,
asset_decimals_map=self._asset_decimals,
market_to_asset_map=self._market_to_asset,
)
)

def initialise_transfer_monitoring(
Expand Down Expand Up @@ -485,6 +513,16 @@ def _monitor_stream(self) -> None:
# get the decimal precision for the asset
self._asset_decimals[update.id]

elif isinstance(update, data.AccountData):
with self.account_lock:
self._accounts_from_feed[update.account_id] = update
self._account_keys_for_party.setdefault(
update.owner, set()
).add(update.account_id)
self._account_keys_for_market.setdefault(
update.market_id, set()
).add(update.account_id)

elif update is None:
logger.debug("Failed to process event into update.")

Expand Down Expand Up @@ -557,3 +595,51 @@ def get_trades_from_stream(
continue
results.append(trade)
return results

def get_accounts_from_stream(
self,
market_id: Optional[str] = None,
asset_id: Optional[str] = None,
party_id: Optional[str] = None,
) -> List[data.AccountData]:
"""Loads accounts for either a given party, market or both from stream.
Must pass one or the other
Args:
market_id:
optional str, Restrict to trades on a specific market
party_id:
optional str, Select only trades with a given id
Returns:
List[AccountData], list of formatted trade objects which match the required
restrictions.
"""
if market_id is None and party_id is None:
raise Exception("At least one of market_id and party_id must be specified")
with self.account_lock:
to_check_keys = set()
if market_id is not None:
to_check_keys.update(
self._account_keys_for_market.get(market_id, set())
)
if party_id is not None:
to_check_keys.update(self._account_keys_for_party.get(party_id, set()))
results = []
for key in to_check_keys:
acct = self._accounts_from_feed[key]
if party_id is not None and acct.owner != party_id:
continue
if (
market_id is not None
and acct.market_id != market_id
and acct.type != vega_protos.vega.AccountType.ACCOUNT_TYPE_GENERAL
):
continue
if market_id is None and acct.market_id != "":
continue
if asset_id is not None and acct.asset != asset_id:
continue
results.append(acct)

return results
Loading

0 comments on commit a87ecb9

Please sign in to comment.