Skip to content

Commit

Permalink
(fix): restarting of session on OsError
Browse files Browse the repository at this point in the history
  • Loading branch information
gautamajay52 committed Sep 30, 2024
1 parent 163155e commit 20f4f3c
Showing 1 changed file with 64 additions and 163 deletions.
227 changes: 64 additions & 163 deletions pyrogram/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import bisect
import logging
import os
from time import time
from hashlib import sha1
from io import BytesIO
from typing import Optional
Expand Down Expand Up @@ -52,14 +51,11 @@ def __init__(self):
class Session:
START_TIMEOUT = 5
WAIT_TIMEOUT = 15
REART_TIMEOUT = 5
SLEEP_THRESHOLD = 10
MAX_RETRIES = 10
ACKS_THRESHOLD = 10
PING_INTERVAL = 5
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
RECONNECT_THRESHOLD = 13
STOP_RANGE = range(2)

TRANSPORT_ERRORS = {
404: "auth key not found",
Expand Down Expand Up @@ -110,20 +106,12 @@ def __init__(
self.recv_task = None

self.is_started = asyncio.Event()
self.restart_event = asyncio.Event()

self.loop = asyncio.get_event_loop()

self.instant_stop = False # set internally
self.last_reconnect_attempt = None
self.currently_restarting = False
self.currently_stopping = False

async def start(self):
while True:
if self.instant_stop:
log.info("session init force stopped (loop)")
return # stop instantly

self.connection = self.client.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
Expand Down Expand Up @@ -183,106 +171,51 @@ async def start(self):

log.info("Session started")

async def stop(self, restart: bool = False):
if self.currently_stopping:
return # don't stop twice
if self.instant_stop:
log.info("session stop process stopped")
return # stop doing anything instantly, client is manually handling
async def stop(self):
self.is_started.clear()

try:
self.currently_stopping = True
self.is_started.clear()
self.stored_msg_ids.clear()

if restart:
self.instant_stop = True # tell all funcs that we want to stop

self.ping_task_event.set()
for _ in self.STOP_RANGE:
try:
if self.ping_task is not None:
await asyncio.wait_for(
self.ping_task, timeout=self.REART_TIMEOUT
)
break
except TimeoutError:
self.ping_task.cancel()
continue # next stage
self.ping_task_event.clear()
self.stored_msg_ids.clear()

self.ping_task_event.set()

if self.ping_task is not None:
await self.ping_task

self.ping_task_event.clear()

await self.connection.close()

if self.recv_task:
await self.recv_task

if not self.is_media and callable(self.client.disconnect_handler):
try:
await asyncio.wait_for(
self.connection.close(), timeout=self.REART_TIMEOUT
)
await self.client.disconnect_handler(self.client)
except Exception as e:
log.exception(e)

for _ in self.STOP_RANGE:
try:
if self.recv_task:
await asyncio.wait_for(
self.recv_task, timeout=self.REART_TIMEOUT
)
break
except TimeoutError:
self.recv_task.cancel()
continue # next stage

if not self.is_media and callable(self.client.disconnect_handler):
try:
await self.client.disconnect_handler(self.client)
except Exception as e:
log.exception(e)

log.info("session stopped")
finally:
self.currently_stopping = False
if restart:
self.instant_stop = False # reset
log.info("Session stopped")

async def restart(self):
if self.currently_restarting:
return # don't restart twice
if self.instant_stop:
return # stop instantly

try:
self.currently_restarting = True
now = time()
if (
self.last_reconnect_attempt
and (now - self.last_reconnect_attempt) < self.RECONNECT_THRESHOLD
):
to_wait = self.RECONNECT_THRESHOLD + int(
self.RECONNECT_THRESHOLD - (now - self.last_reconnect_attempt)
)
log.warning(
"[pyrogram] Client [%s] is reconnecting too frequently, sleeping for %s seconds",
self.client.name,
to_wait
)
await asyncio.sleep(to_wait)

self.last_reconnect_attempt = now
await self.stop(restart=True)
await self.start()
finally:
self.currently_restarting = False
self.restart_event.set()
await self.stop()
await self.start()
self.restart_event.clear()

async def handle_packet(self, packet):
if self.instant_stop:
log.info("Stopped packet handler")
return # stop instantly

data = await self.loop.run_in_executor(
pyrogram.crypto_executor,
mtproto.unpack,
BytesIO(packet),
self.session_id,
self.auth_key,
self.auth_key_id
)
try:
data = await self.loop.run_in_executor(
pyrogram.crypto_executor,
mtproto.unpack,
BytesIO(packet),
self.session_id,
self.auth_key,
self.auth_key_id
)
except ValueError as e:
log.debug(e)
self.loop.create_task(self.restart())
return

messages = (
data.body.messages
Expand Down Expand Up @@ -360,17 +293,9 @@ async def handle_packet(self, packet):
self.pending_acks.clear()

async def ping_worker(self):
if self.instant_stop:
log.info("PingTask force stopped")
return # stop instantly

log.info("PingTask started")

while True:
if self.instant_stop:
log.info("PingTask force stopped (loop)")
return # stop instantly

try:
await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL)
except asyncio.TimeoutError:
Expand All @@ -396,10 +321,6 @@ async def recv_worker(self):
log.info("NetworkTask started")

while True:
if self.instant_stop:
log.info("NetworkTask force stopped (loop)")
return # stop instantly

packet = await self.connection.recv()

if packet is None or len(packet) == 4:
Expand All @@ -412,10 +333,8 @@ async def recv_worker(self):
# "and log in again with your phone number or bot token."
# )
log.warning(
"[%s] Server sent transport error: %s (%s)",
self.client.name,
error_code,
Session.TRANSPORT_ERRORS.get(error_code, "unknown error"),
"Server sent transport error: %s (%s)",
error_code, Session.TRANSPORT_ERRORS.get(error_code, "unknown error")
)

if self.is_started.is_set():
Expand All @@ -431,11 +350,8 @@ async def send(
self,
data: TLObject,
wait_response: bool = True,
timeout: float = WAIT_TIMEOUT,
timeout: float = WAIT_TIMEOUT
):
if self.instant_stop:
return # stop instantly

message = self.msg_factory(data)
msg_id = message.msg_id

Expand Down Expand Up @@ -493,27 +409,19 @@ async def invoke(
timeout: float = WAIT_TIMEOUT,
sleep_threshold: float = SLEEP_THRESHOLD
):
try:
await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT)
except asyncio.TimeoutError:
pass

if isinstance(query, Session.CUR_ALWD_INNR_QRYS):
inner_query = query.query
else:
inner_query = query

query_name = ".".join(inner_query.QUALNAME.split(".")[1:])

while retries > 0:
# sleep until the restart is performed
if self.currently_restarting:
while self.currently_restarting:
if self.instant_stop:
return # stop instantly
await asyncio.sleep(1)

if self.instant_stop:
return # stop instantly

if not self.is_started.is_set():
await self.is_started.wait()

while True:
try:
return await self.send(query, timeout=timeout)
except (FloodWait, FloodPremiumWait) as e:
Expand All @@ -522,12 +430,8 @@ async def invoke(
if amount > sleep_threshold >= 0:
raise

log.warning(
'[%s] Waiting for %s seconds before continuing (required by "%s")',
self.client.name,
amount,
query_name,
)
log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")',
self.client.name, amount, query_name)

await asyncio.sleep(amount)
except (
Expand All @@ -550,26 +454,23 @@ async def invoke(
):
raise e from None

if (isinstance(e, (OSError, RuntimeError)) and "handler" in str(e)) or (
isinstance(e, TimeoutError)
):
(log.warning if retries < 2 else log.info)(
'[%s] [%s] reconnecting session requesting "%s", due to: %s',
self.client.name,
Session.MAX_RETRIES - retries,
query_name,
str(e) or repr(e),
)
(log.warning if retries < 2 else log.info)(
'[%s] Retrying "%s" due to: %s',
Session.MAX_RETRIES - retries + 1,
query_name, str(e) or repr(e)
)

# restart was never being called after Exception block
if not self.restart_event.is_set():
self.loop.create_task(self.restart())
else:
(log.warning if retries < 2 else log.info)(
'[%s] [%s] Retrying "%s" due to: %s',
self.client.name,
Session.MAX_RETRIES - retries,
query_name,
str(e) or repr(e),
)

await asyncio.sleep(1)

# multiple Exceptions can be raised in a row, so we need to wait for the restart to finish
try:
await asyncio.wait_for(self.restart_event.wait(), self.WAIT_TIMEOUT)
except asyncio.TimeoutError:
pass

await asyncio.sleep(0.5)

return await self.invoke(query, retries - 1, timeout)
raise TimeoutError("Exceeded maximum number of retries")

0 comments on commit 20f4f3c

Please sign in to comment.