diff --git a/src/brad/daemon/transition_orchestrator.py b/src/brad/daemon/transition_orchestrator.py index a08f5e26..75b6c208 100644 --- a/src/brad/daemon/transition_orchestrator.py +++ b/src/brad/daemon/transition_orchestrator.py @@ -480,8 +480,11 @@ async def _run_aurora_post_transition( table_diffs: Optional[list[TableDiff]], ) -> None: # Drop removed tables. + assert self._curr_blueprint is not None + aurora_on = self._curr_blueprint.aurora_provisioning().num_nodes() > 0 if ( - table_diffs is not None + aurora_on + and table_diffs is not None and len(table_diffs) > 0 and self._config.disable_table_movement is False and self._config.skip_aurora_table_deletion is False @@ -626,7 +629,13 @@ async def _run_redshift_post_transition( self, diff: Optional[ProvisioningDiff], table_diffs: Optional[list[TableDiff]] ) -> None: # Drop removed tables - if table_diffs is not None and self._config.disable_table_movement is False: + assert self._curr_blueprint is not None + redshift_on = self._curr_blueprint.redshift_provisioning().num_nodes() > 0 + if ( + redshift_on + and table_diffs is not None + and self._config.disable_table_movement is False + ): if self._system_event_logger is not None: self._system_event_logger.log( SystemEvent.PostTableMovementStarted, "redshift" @@ -825,9 +834,9 @@ def _new_execution_context(self) -> ExecutionContext: nonsilent_assert(self._cxns is not None) assert self._cxns is not None return ExecutionContext( - aurora=self._cxns.get_connection(Engine.Aurora), - athena=self._cxns.get_connection(Engine.Athena), - redshift=self._cxns.get_connection(Engine.Redshift), + aurora=self._cxns.get_connection_if_exists(Engine.Aurora), + athena=self._cxns.get_connection_if_exists(Engine.Athena), + redshift=self._cxns.get_connection_if_exists(Engine.Redshift), blueprint=self._blueprint_mgr.get_blueprint(), config=self._config, ) diff --git a/src/brad/front_end/engine_connections.py b/src/brad/front_end/engine_connections.py index 5c63dd3c..b1aa041c 100644 --- a/src/brad/front_end/engine_connections.py +++ b/src/brad/front_end/engine_connections.py @@ -274,6 +274,12 @@ def get_connection(self, engine: Engine) -> Connection: except KeyError as ex: raise RuntimeError("Not connected to {}".format(engine)) from ex + def get_connection_if_exists(self, engine: Engine) -> Optional[Connection]: + try: + return self._connection_map[engine] + except KeyError: + return None + def get_reader_connection( self, engine: Engine, specific_index: Optional[int] = None ) -> Connection: