diff --git a/cellfinder/core/download/cli.py b/cellfinder/core/download/cli.py index e3c6aa46..c97bbd18 100644 --- a/cellfinder/core/download/cli.py +++ b/cellfinder/core/download/cli.py @@ -3,7 +3,7 @@ from pathlib import Path from cellfinder.core.download import models -from cellfinder.core.download.download import amend_cfg +from cellfinder.core.download.download import amend_user_configuration home = Path.home() DEFAULT_DOWNLOAD_DIRECTORY = home / ".cellfinder" @@ -65,7 +65,7 @@ def main(): model_path = models.main(args.model, args.install_path) if not args.no_amend_config: - amend_cfg(new_model_path=model_path) + amend_user_configuration(new_model_path=model_path) if __name__ == "__main__": diff --git a/cellfinder/core/download/download.py b/cellfinder/core/download/download.py index ee242daa..00d6dee3 100644 --- a/cellfinder/core/download/download.py +++ b/cellfinder/core/download/download.py @@ -7,8 +7,8 @@ from brainglobe_utils.general.system import disk_free_gb from cellfinder.core.tools.source_files import ( - source_config_cellfinder, - source_custom_config_cellfinder, + default_configuration_path, + user_specific_configuration_path, ) @@ -75,16 +75,45 @@ def download( os.remove(download_path) -def amend_cfg(new_model_path=None): - print("Ensuring custom config file is correct") - - original_config = source_config_cellfinder() - new_config = source_custom_config_cellfinder() - if new_model_path is not None: - write_model_to_cfg(new_model_path, original_config, new_config) +def amend_user_configuration(new_model_path=None) -> None: + """ + Amends the user configuration to contain the configuration + in new_model_path, if specified. + Parameters + ---------- + new_model_path : str, optional + The path to the new model configuration. + """ + print("(Over-)writing custom user configuration") -def write_model_to_cfg(new_model_path, orig_config, custom_config): + original_config = default_configuration_path() + new_config = user_specific_configuration_path() + if new_model_path is not None: + write_model_to_config(new_model_path, original_config, new_config) + + +def write_model_to_config(new_model_path, orig_config, custom_config): + """ + Update the model path in the custom configuration file, by + reading the lines in the original configuration file, replacing + the line starting with "model_path =" and writing these + lines to the custom file. + + Parameters + ---------- + new_model_path : str + The new path to the model. + orig_config : str + The path to the original configuration file. + custom_config : str + The path to the custom configuration file to be created. + + Returns + ------- + None + + """ config_obj = get_config_obj(orig_config) model_conf = config_obj["model"] orig_path = model_conf["model_path"] diff --git a/cellfinder/core/tools/prep.py b/cellfinder/core/tools/prep.py index e513917d..1e0bccf3 100644 --- a/cellfinder/core/tools/prep.py +++ b/cellfinder/core/tools/prep.py @@ -14,8 +14,8 @@ import cellfinder.core.tools.tf as tf_tools from cellfinder.core import logger from cellfinder.core.download import models as model_download -from cellfinder.core.download.download import amend_cfg -from cellfinder.core.tools.source_files import source_custom_config_cellfinder +from cellfinder.core.download.download import amend_user_configuration +from cellfinder.core.tools.source_files import user_specific_configuration_path home = Path.home() DEFAULT_INSTALL_PATH = home / ".cellfinder" @@ -49,18 +49,18 @@ def prep_models( if model_weights_path is None: logger.debug("No model supplied, so using the default") - config_file = source_custom_config_cellfinder() + config_file = user_specific_configuration_path() if not Path(config_file).exists(): logger.debug("Custom config does not exist, downloading models") model_path = model_download.main(model_name, install_path) - amend_cfg(new_model_path=model_path) + amend_user_configuration(new_model_path=model_path) model_weights = get_model_weights(config_file) if not model_weights.exists(): logger.debug("Model weights do not exist, downloading") model_path = model_download.main(model_name, install_path) - amend_cfg(new_model_path=model_path) + amend_user_configuration(new_model_path=model_path) model_weights = get_model_weights(config_file) else: model_weights = Path(model_weights_path) diff --git a/cellfinder/core/tools/source_files.py b/cellfinder/core/tools/source_files.py index 51b16c07..474cc51e 100644 --- a/cellfinder/core/tools/source_files.py +++ b/cellfinder/core/tools/source_files.py @@ -1,9 +1,27 @@ from pathlib import Path -def source_config_cellfinder(): +def default_configuration_path(): + """ + Returns the default configuration path for cellfinder. + + Returns: + Path: The default configuration path. + """ return Path(__file__).parent.parent / "config" / "cellfinder.conf" -def source_custom_config_cellfinder(): - return Path(__file__).parent.parent / "config" / "cellfinder.conf.custom" +def user_specific_configuration_path(): + """ + Returns the path to the user-specific configuration file for cellfinder. + + This function returns the path to the user-specific configuration file + for cellfinder. The user-specific configuration file is located in the + user's home directory under the ".cellfinder" folder and is named + "cellfinder.conf.custom". + + Returns: + Path: The path to the custom configuration file. + + """ + return Path.home() / ".cellfinder" / "cellfinder.conf.custom"