diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 403ee4e59..d3c54113c 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -286,55 +286,54 @@ def run(self) -> None: callback=self._stream_write_hook, ) - self._setup(redirector) - self._loop(redirector) + with redirector: + self._setup(redirector) + self._loop(redirector) def send_cancel(self) -> None: if self.is_alive() and self.pid: os.kill(self.pid, signal.SIGUSR1) def _setup(self, redirector: StreamRedirector) -> None: - with redirector: - done = Done() + done = Done() + try: + self._predictor = load_predictor_from_ref(self._predictor_ref) + # Could be a function or a class + if hasattr(self._predictor, "setup"): + run_setup(self._predictor) + except Exception as e: # pylint: disable=broad-exception-caught + traceback.print_exc() + done.error = True + done.error_detail = str(e) + except BaseException as e: + # For SystemExit and friends we attempt to add some useful context + # to the logs, but reraise to ensure the process dies. + traceback.print_exc() + done.error = True + done.error_detail = str(e) + raise + finally: try: - self._predictor = load_predictor_from_ref(self._predictor_ref) - # Could be a function or a class - if hasattr(self._predictor, "setup"): - run_setup(self._predictor) - except Exception as e: # pylint: disable=broad-exception-caught - traceback.print_exc() - done.error = True - done.error_detail = str(e) - except BaseException as e: - # For SystemExit and friends we attempt to add some useful context - # to the logs, but reraise to ensure the process dies. - traceback.print_exc() - done.error = True - done.error_detail = str(e) - raise - finally: - try: - redirector.drain(timeout=10) - except TimeoutError: - self._events.send( - Log( - "WARNING: logs may be truncated due to excessive volume.", - source="stderr", - ) + redirector.drain(timeout=10) + except TimeoutError: + self._events.send( + Log( + "WARNING: logs may be truncated due to excessive volume.", + source="stderr", ) - raise - self._events.send(done) + ) + raise + self._events.send(done) def _loop(self, redirector: StreamRedirector) -> None: - with redirector: - while True: - ev = self._events.recv() - if isinstance(ev, Shutdown): - break - if isinstance(ev, PredictionInput): - self._predict(ev.payload, redirector) - else: - print(f"Got unexpected event: {ev}", file=sys.stderr) + while True: + ev = self._events.recv() + if isinstance(ev, Shutdown): + break + if isinstance(ev, PredictionInput): + self._predict(ev.payload, redirector) + else: + print(f"Got unexpected event: {ev}", file=sys.stderr) def _predict( self,