diff --git a/osu/game.py b/osu/game.py index d2ff1c9..32a3993 100644 --- a/osu/game.py +++ b/osu/game.py @@ -72,26 +72,12 @@ def __init__( self.version = version self.tourney = tournament self.disable_chat = disable_chat_logging - self.version_number = version self.logger = logging.getLogger("osu!") self.logger.disabled = disable_logging - if not version: - # Fetch latest client version - self.version = self.fetch_version(stream) - - if not self.version: - # Failed to get version - exit(1) - else: - # Custom client version was set - self.version = f"b{self.version}" - - if self.tourney: - self.version = f"{self.version}tourney" - + self.resolve_version() self.session = requests.Session() self.session.headers = {"User-Agent": "osu!", "osu-version": self.version} @@ -114,6 +100,10 @@ def __init__( if tasks: self.tasks.tasks = tasks + if not self.version or not self.version_number: + # Failed to get version + exit(1) + if not (updates := self.api.check_updates()): # Updates are required because of the executable hash # TODO: Custom executable hash? @@ -179,6 +169,22 @@ async def run_async(self) -> None: loop = asyncio.get_event_loop() await loop.run_in_executor(None, self.run, False, False) + def resolve_version(self) -> None: + """Ensure the client version is set""" + if not self.version: + # Fetch latest client version + self.version = self.fetch_version(self.stream) + return + + if type(self.version) not in (float, int): + raise ValueError("Invalid version number") + + # Custom client version was set + self.version = f"b{self.version}" + + if self.tourney: + self.version = f"{self.version}tourney" + def fetch_version(self, stream: str = "stable40") -> Optional[str]: """ Fetch the latest version of the client from: diff --git a/osu/tcp/game.py b/osu/tcp/game.py index 3d0693a..5507541 100644 --- a/osu/tcp/game.py +++ b/osu/tcp/game.py @@ -77,17 +77,12 @@ def __init__( self.version = version self.tourney = tournament self.disable_chat = disable_chat_logging - self.version_number = version self.logger = logging.getLogger("osu!") self.logger.disabled = disable_logging - self.version = f"b{self.version}" - - if self.tourney: - self.version = f"{self.version}tourney" - + self.resolve_version() self.session = requests.Session() self.session.headers = {"User-Agent": "osu!", "osu-version": self.version} @@ -162,3 +157,13 @@ async def run_async(self) -> None: loop = asyncio.get_event_loop() await loop.run_in_executor(None, self.run, False, False) + + def resolve_version(self) -> None: + """Ensure a correct client version was set""" + if type(self.version) not in (float, int): + raise ValueError("Invalid version number") + + self.version = f"b{self.version}" + + if self.tourney: + self.version = f"{self.version}tourney"