From c0165ba12eb9dc61aa43ed2d5f5225f299ef073b Mon Sep 17 00:00:00 2001
From: David I
Date: Wed, 6 Mar 2024 11:50:37 -0500
Subject: [PATCH 01/10] fixing pypi

---
 .gitignore              | 3 ++-
 README.md               | 2 +-
 src/eclipse/__init__.py | 1 +
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index 2d799fd..d3c4be1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,5 @@ ner_model_bert
 *.pyc
 .DS_Store
 eclipse_ai.egg-info
-dist
\ No newline at end of file
+dist
+local_build.md
\ No newline at end of file
diff --git a/README.md b/README.md
index bcf50c8..577b775 100644
--- a/README.md
+++ b/README.md
@@ -151,7 +151,7 @@ Additional Options
 ## Usage as a module
 
 ```python
-from eclipse import process_text # Replace 'your_script_name' with the actual name of the script without '.py'
+from eclipse import process_text
 
 # Set the path to the pretrained BERT model. This should be the same as DEFAULT_MODEL_PATH in the script
 model_path = "./ner_model_bert"
diff --git a/src/eclipse/__init__.py b/src/eclipse/__init__.py
index e69de29..b0de1a4 100644
--- a/src/eclipse/__init__.py
+++ b/src/eclipse/__init__.py
@@ -0,0 +1 @@
+from .eclipse import process_text
\ No newline at end of file

From 6e3db40bc3538c7125784cccea806e69e056abbb Mon Sep 17 00:00:00 2001
From: David I
Date: Wed, 6 Mar 2024 12:19:12 -0500
Subject: [PATCH 02/10] adding option to set model download directory

---
 src/eclipse/eclipse.py | 35 +++++++++++++++++++++--------------
 1 file changed, 21 insertions(+), 14 deletions(-)

diff --git a/src/eclipse/eclipse.py b/src/eclipse/eclipse.py
index 178dd9a..7681d15 100644
--- a/src/eclipse/eclipse.py
+++ b/src/eclipse/eclipse.py
@@ -174,42 +174,42 @@ def save_local_metadata(file_name, etag):
         json.dump({"etag": etag}, f)
 
 
-def ensure_model_folder_exists():
-    metadata_file = os.path.join(DEFAULT_MODEL_PATH, "metadata.json")
+def ensure_model_folder_exists(model_directory):
+    metadata_file = os.path.join(model_directory, "metadata.json")
     local_etag = get_local_metadata(metadata_file)
     s3_etag = get_s3_file_etag(s3_url)
     if s3_etag is None:
         return  # Exit if there's no internet connection or other issues with S3
 
     # Check if the model directory exists and has the same etag (metadata)
-    if folder_exists_and_not_empty(DEFAULT_MODEL_PATH) and local_etag == s3_etag:
-        cprint(f"Model directory {DEFAULT_MODEL_PATH} is up-to-date.", "green")
+    if folder_exists_and_not_empty(model_directory) and local_etag == s3_etag:
+        cprint(f"Model directory {model_directory} is up-to-date.", "green")
         return  # No need to update anything as local version matches S3 version
 
     # If folder doesn't exist, is empty, or etag doesn't match, prompt for download.
-    user_input = "y"
-    if local_etag and local_etag != s3_etag:
-        user_input = get_input_with_default(
-            "New versions of the models are available, would you like to download them? (y/n) "
-        )
+    user_input = get_input_with_default(
+        "New versions of the models are available, would you like to download them? (y/n) ",
+        default_text="y"  # Automatically opt for download if not specified otherwise
+    )
 
-    if user_input.lower() != "y":
+    if user_input.lower() != 'y':
         return  # Exit if user chooses not to update
 
     # Logic to remove the model directory if it exists
-    if os.path.exists(DEFAULT_MODEL_PATH):
+    if os.path.exists(model_directory):
         cprint("Removing existing model folder...", "yellow")
-        shutil.rmtree(DEFAULT_MODEL_PATH)
+        shutil.rmtree(model_directory)
 
     cprint(
-        f"{DEFAULT_MODEL_PATH} not found or is different. Downloading and unzipping...",
+        f"{model_directory} not found or is outdated. Downloading and unzipping...",
         "yellow",
     )
-    download_and_unzip(s3_url, f"{DEFAULT_MODEL_PATH}.zip")
+    download_and_unzip(s3_url, f"{model_directory}.zip")
 
     # Save new metadata
     save_local_metadata(metadata_file, s3_etag)
 
+
 def is_internet_available(host="8.8.8.8", port=53, timeout=3):
     """Check if there is an internet connection."""
     try:
@@ -356,6 +356,13 @@ def main():
         action="store_true",
         help="Enable GPU usage for model inference.",
     )
+    parser.add_argument(
+        "-dir", "--model_directory",
+        type=str,
+        default=DEFAULT_MODEL_PATH,
+        help="Directory where the BERT model should be downloaded and unzipped."
+)
+
     args = parser.parse_args()
 
     # Determine whether to use the GPU or not based on the user's command line input

From 1da48dd211555ee84d12e5c258be3e4a1d8d45fb Mon Sep 17 00:00:00 2001
From: David I
Date: Wed, 6 Mar 2024 15:05:11 -0500
Subject: [PATCH 03/10] adding support for loading the model only once

---
 README.md               |  60 +++++--
 src/eclipse/__init__.py |   1 -
 src/eclipse/eclipse.py  | 381 +++++++++++++++++++++++++++-------------
 3 files changed, 304 insertions(+), 138 deletions(-)

diff --git a/README.md b/README.md
index 577b775..d677210 100644
--- a/README.md
+++ b/README.md
@@ -115,6 +115,7 @@ pip install eclipse-ai --upgrade
 ``` bash
 usage: eclipse [-h] [-p PROMPT] [-f FILE] [-m MODEL_PATH] [-o OUTPUT] [--debug] [-d DELIMITER] [-g]
+               [-dir MODEL_DIRECTORY]
 
 Entity recognition using BERT.
 
@@ -131,7 +132,8 @@ options:
   -d DELIMITER, --delimiter DELIMITER
                         Delimiter to separate text inputs, defaults to newline.
   -g, --use_gpu         Enable GPU usage for model inference.
-
+  -dir MODEL_DIRECTORY, --model_directory MODEL_DIRECTORY
+                        Directory where the BERT model should be downloaded and unzipped.
 ```
 
 Here are some examples:
@@ -151,22 +153,54 @@ Additional Options
 ## Usage as a module
 
 ```python
-from eclipse import process_text
+from eclipse import process_text # Make sure this is the correct import
 
-# Set the path to the pretrained BERT model. This should be the same as DEFAULT_MODEL_PATH in the script
 model_path = "./ner_model_bert"
-
-# Example text to process
 input_text = "Your example text here."
 
-# Process the text
-# The 'device' argument is either 'cpu' or 'cuda' depending on whether you are using CPU or GPU
-processed_text, highest_avg_label, highest_avg_confidence, is_high_confidence = process_text(input_text, model_path, 'cpu')
+line_by_line = True # Change this to False if you want to process the whole text at once
+
+try:
+    # Handle both line-by-line processing and whole text processing
+    if line_by_line:
+        # Process the text line by line
+        for result in process_text(input_text, model_path, 'cpu', line_by_line=True):
+            if result:
+                processed_text, highest_avg_label, highest_avg_confidence, is_high_confidence = result
+                print(f"Processed Text: {processed_text}")
+                print(f"Highest Average Label: {highest_avg_label}")
+                print(f"Highest Average Confidence: {highest_avg_confidence}")
+                print(f"Is High Confidence: {is_high_confidence}")
+            else:
+                # If result is empty (which should not happen in line-by-line mode), assign default values
+                print("Processed Text: Error in processing")
+                print("Highest Average Label: Error")
+                print("Highest Average Confidence: Error")
+                print("Is High Confidence: Error")
+    else:
+        # If line_by_line is set to False, expecting a single result
+        result = process_text(input_text, model_path, 'cpu', line_by_line=False)
+        if result:  # Checking if result is not empty
+            processed_text, highest_avg_label, highest_avg_confidence, is_high_confidence = result
+            print(f"Processed Text: {processed_text}")
+            print(f"Highest Average Label: {highest_avg_label}")
+            print(f"Highest Average Confidence: {highest_avg_confidence}")
+            print(f"Is High Confidence: {is_high_confidence}")
+        else:
+            # If result is empty, assign default values
+            print("Processed Text: Error in processing")
+            print("Highest Average Label: Error")
+            print("Highest Average Confidence: Error")
+            print("Is High Confidence: Error")
+
+except Exception as e:  # Catching general exceptions
+    print(f"Error processing text: {e}")
+    # Default error handling values
+    print("Processed Text: Error in processing")
+    print("Highest Average Label: Error")
+    print("Highest Average Confidence: Error")
+    print("Is High Confidence: Error")
 
-print(f"Processed Text: {processed_text}")
-print(f"Highest Average Label: {highest_avg_label}")
-print(f"Highest Average Confidence: {highest_avg_confidence}")
-print(f"Is High Confidence: {is_high_confidence}")
 ```
 
 ## Understanding the Output
diff --git a/src/eclipse/__init__.py b/src/eclipse/__init__.py
index b0de1a4..e69de29 100644
--- a/src/eclipse/__init__.py
+++ b/src/eclipse/__init__.py
@@ -1 +0,0 @@
-from .eclipse import process_text
\ No newline at end of file
diff --git a/src/eclipse/eclipse.py b/src/eclipse/eclipse.py
index 7681d15..05e5285 100644
--- a/src/eclipse/eclipse.py
+++ b/src/eclipse/eclipse.py
@@ -6,7 +6,7 @@ import shutil
 import socket
 import subprocess
-from collections import defaultdict
+import warnings
 from importlib.metadata import version
 from typing import List, Set, Tuple
 from zipfile import ZipFile
@@ -16,10 +16,42 @@ from prompt_toolkit import prompt
 from prompt_toolkit.history import InMemoryHistory
 from prompt_toolkit.styles import Style
-from termcolor import cprint
 from transformers import BertForTokenClassification, BertTokenizerFast
 
-logging.getLogger("transformers").setLevel(logging.ERROR)
+# Suppress specific warning from transformers
+warnings.filterwarnings(
+    "ignore",
+    message="Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.",
+)
+
+
+# Configure basic logging
+# This will set the log level to ERROR, meaning only error and critical messages will be logged
+# You can specify a filename to write the logs to a file; otherwise, it will log to stderr
+log_file_path = os.path.join(os.path.expanduser("~"), "eclipse.log")
+
+# Configure basic logging
+# Get the user's home directory from the HOME environment variable
+home_directory = os.getenv("HOME")  # This returns None if 'HOME' is not set
+
+if home_directory:
+    log_file_path = os.path.join(home_directory, "eclipse.log")
+else:
+    # Fallback mechanism or throw an error
+    log_file_path = (
+        "eclipse.log"  # Default to current directory, or handle error as needed
+    )
+
+# Configure basic logging
+logging.basicConfig(
+    filename=log_file_path,
+    level=logging.ERROR,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+
+
+# Test logging at different levels
 
 s3_url = "https://nebula-models.s3.amazonaws.com/ner_model_bert.zip"  # Update this with your actual S3 URL
 # Define the label mappings
@@ -35,9 +67,33 @@
 DEFAULT_MODEL_PATH = "./ner_model_bert"
 
 
+class ModelManager:
+    instance = None
+
+    class __ModelManager:
+        def __init__(self, model_path, device):
+            self.device = torch.device(
+                "cuda" if torch.cuda.is_available() and device == "cuda" else "cpu"
+            )
+            self.tokenizer = BertTokenizerFast.from_pretrained(model_path)
+            self.model = BertForTokenClassification.from_pretrained(model_path)
+            self.model.config.id2label = id_to_label
+            self.model.config.label2id = label_to_id
+            self.model.to(self.device)
+            self.model.eval()
+
+    @staticmethod
+    def get_instance(model_path=DEFAULT_MODEL_PATH, device="cpu"):
+        if ModelManager.instance is None:
+            ModelManager.instance = ModelManager.__ModelManager(model_path, device)
+        return ModelManager.instance
+
+
 def get_s3_file_etag(s3_url):
     if not is_internet_available():
-        cprint("No internet connection available. Skipping version check.", "red")
+        logging.error(
+            "No internet connection available. Skipping version check.", "red"
+        )
         return None
     response = requests.head(s3_url)
     return response.headers.get("ETag")
@@ -63,39 +119,45 @@ def get_latest_pypi_version(package_name):
         if response.status_code == 200:
             return response.json()["info"]["version"]
     except requests.exceptions.RequestException as e:
-        cprint(f"Failed to get latest version information: {e}", "red")
+        logging.error(f"Failed to get latest version information: {e}", "red")
     return None
 
 
 def check_new_pypi_version(package_name="eclipse-ai"):
     """Check if a newer version of the package is available on PyPI."""
     if not is_internet_available():
-        cprint("No internet connection available. Skipping version check.", "red")
+        logging.error(
+            "No internet connection available. Skipping version check.", "red"
+        )
         return
 
     try:
         installed_version = version(package_name)
     except Exception as e:
-        cprint(f"Error retrieving installed version of {package_name}: {e}", "red")
+        logging.error(
+            f"Error retrieving installed version of {package_name}: {e}", "red"
+        )
         return
 
-    cprint(f"Installed version: {installed_version}", "green")
+    logging.info(f"Installed version: {installed_version}", "green")
 
     try:
         latest_version = get_latest_pypi_version(package_name)
         if latest_version is None:
-            cprint(
+            logging.error(
                 f"Error retrieving latest version of {package_name} from PyPI.", "red"
             )
             return
 
         if latest_version > installed_version:
-            cprint(
+            logging.info(
                 f"A newer version ({latest_version}) of {package_name} is available on PyPI. Please consider updating to access the latest features!",
                 "yellow",
             )
     except Exception as e:
-        cprint(f"An error occurred while checking for the latest version: {e}", "red")
+        logging.error(
+            f"An error occurred while checking for the latest version: {e}", "red"
+        )
 
 
 def get_input_with_default(message, default_text=None):
@@ -121,7 +183,7 @@ def folder_exists_and_not_empty(folder_path):
 def download_and_unzip(url, output_name):
     try:
         # Download the file from the S3 bucket using wget with progress bar
-        print("Downloading...")
+        logging.info("Downloading...")
         subprocess.run(
             ["wget", "--progress=bar:force:noscroll", url, "-O", output_name],
             check=True,
@@ -131,7 +193,7 @@
         target_dir = os.path.splitext(output_name)[0]  # Removes '.zip' from output_name
 
         # Extract the ZIP file
-        print("\nUnzipping...")
+        logging.info("\nUnzipping...")
         with ZipFile(output_name, "r") as zip_ref:
             # Here we will extract in a temp directory to inspect the structure
             temp_dir = "temp_extract_dir"
@@ -162,10 +224,10 @@
         # Remove the ZIP file to clean up
         os.remove(output_name)
     except subprocess.CalledProcessError as e:
-        cprint(f"Error occurred during download: {e}", "red")
+        logging.error(f"Error occurred during download: {e}", "red")
         logging.error(f"Error occurred during download: {e}")
     except Exception as e:
-        cprint(f"Unexpected error: {e}", "red")
+        logging.error(f"Unexpected error: {e}", "red")
         logging.error(f"Unexpected error: {e}")
 
 
@@ -174,7 +236,7 @@ def save_local_metadata(file_name, etag):
         json.dump({"etag": etag}, f)
 
 
-def ensure_model_folder_exists(model_directory):
+def ensure_model_folder_exists(model_directory, auto_update=True):
     metadata_file = os.path.join(model_directory, "metadata.json")
     local_etag = get_local_metadata(metadata_file)
     s3_etag = get_s3_file_etag(s3_url)
@@ -183,24 +245,29 @@
         return  # Exit if there's no internet connection or other issues with S3
 
     # Check if the model directory exists and has the same etag (metadata)
     if folder_exists_and_not_empty(model_directory) and local_etag == s3_etag:
-        cprint(f"Model directory {model_directory} is up-to-date.", "green")
+        logging.info(f"Model directory {model_directory} is up-to-date.", "green")
         return  # No need to update anything as local version matches S3 version
 
-    # If folder doesn't exist, is empty, or etag doesn't match, prompt for download.
-    user_input = get_input_with_default(
-        "New versions of the models are available, would you like to download them? (y/n) ",
-        default_text="y"  # Automatically opt for download if not specified otherwise
-    )
+    if not auto_update:
+        # If folder doesn't exist, is empty, or etag doesn't match, prompt for download.
+        user_input = get_input_with_default(
+            "New versions of the models are available, would you like to download them? (y/n) ",
+            default_text="y",  # Automatically opt for download if not specified otherwise
+        )
 
-    if user_input.lower() != 'y':
-        return  # Exit if user chooses not to update
+        if user_input.lower() != "y":
+            return  # Exit if user chooses not to update
+    else:
+        logging.info(
+            "Auto-update is enabled. Downloading new version if necessary...", "yellow"
+        )
 
-    # Logic to remove the model directory if it exists
+    # Proceed with the removal of the existing model directory and the download of the new version
     if os.path.exists(model_directory):
-        cprint("Removing existing model folder...", "yellow")
+        logging.info("Removing existing model folder...", "yellow")
         shutil.rmtree(model_directory)
 
-    cprint(
+    logging.info(
         f"{model_directory} not found or is outdated. Downloading and unzipping...",
         "yellow",
     )
@@ -209,7 +276,6 @@
     save_local_metadata(metadata_file, s3_etag)
 
 
-
 def is_internet_available(host="8.8.8.8", port=53, timeout=3):
     """Check if there is an internet connection."""
     try:
@@ -259,64 +325,95 @@
             detected_labels = {label for label in predictions_labels if label != "O"}
         return detected_labels, predictions_labels, confidence_list, average_confidence
     except Exception as e:
-        print(f"An error occurred in recognize_entities_bert: {e}")
+        logging.error(f"An error occurred in recognize_entities_bert: {e}")
         # Return empty sets and lists if an error occurs
         return set(), [], [], 0.0
 
 
 def process_text(
-    input_text: str, model_path: str, device: str
-) -> Tuple[str, defaultdict, float, bool]:
-    try:
-        device = torch.device(
-            "cuda" if torch.cuda.is_available() and device == "cuda" else "cpu"
-        )
-        tokenizer = BertTokenizerFast.from_pretrained(model_path)
-        model = BertForTokenClassification.from_pretrained(model_path)
-
-        # Ensure the model uses the correct label mappings
-        model.config.id2label = id_to_label
-        model.config.label2id = label_to_id
-
-        model.to(device)
-        model.eval()
-
-        (
-            unique_labels_detected,
-            labels_detected,
-            confidences,
-            avg_confidence,
-        ) = recognize_entities_bert(input_text, model, tokenizer, device)
-
-        # Collate labels and their corresponding confidences
-        label_confidences = defaultdict(list)
-        for label, conf in zip(labels_detected, confidences):
-            label_confidences[label].append(conf)
-
-        # Determine the label with the highest average confidence
-        highest_avg_label, highest_avg_conf = max(
-            label_confidences.items(),
-            key=lambda lc: sum(lc[1]) / len(lc[1]),
-            default=("BENIGN", [avg_confidence]),
-        )
-
-        # Determine if non-'BENIGN' labels exist and average confidence is above threshold
-        non_benign_high_conf = (
-            "BENIGN" not in highest_avg_label and avg_confidence > 0.80
-        )
-        return input_text, highest_avg_label, highest_avg_conf, non_benign_high_conf
-    except Exception as e:
-        print(f"An error occurred in process_text: {e}")
-        # Return empty results and false for high confidence if an error occurs
-        return input_text, defaultdict(list), 0.0, False
+    input_text: str, model_path: str, device: str, line_by_line: bool = False
+):
+    # Ensure the model folder exists and is up to date
+    ensure_model_folder_exists(model_path)
+    # Get the singleton instance of the model manager for the specified model path and device
+    model_manager = ModelManager.get_instance(model_path, device)
+
+    # Define an inner function for processing a single line
+    def process_single(input_line):
+        try:
+            # Process the single line using the shared model and tokenizer
+            return process_single_line(input_line, model_manager, device)
+        except Exception as e:
+            logging.error(f"An error occurred while processing single line: {e}")
+            return input_line, "Error", [0], False
+
+    # Define an inner generator function for processing line by line
+    def process_multiple(input_lines):
+        for line in input_lines.split("\n"):
+            try:
+                # Yield results from processing each line individually
+                yield process_single_line(line, model_manager, device)
+            except Exception as e:
+                logging.error(
+                    f"An error occurred while processing line: {line}, Error: {e}"
+                )
+                yield line, "Error", [0], False
+
+    # Process the input text based on the line_by_line flag
+    if line_by_line:
+        return process_multiple(input_text)  # This returns a generator
+    else:
+        return process_single(input_text)  # This returns a single tuple
+
+
+def process_single_line(line: str, model_manager, device):
+    # Access the tokenizer and model from the model manager
+    tokenizer = model_manager.tokenizer
+    model = model_manager.model
+    device = model_manager.device  # Use device from the model manager
+
+    # Tokenize the input line
+    tokenized_inputs = tokenizer(
+        line, truncation=True, padding=True, max_length=512, return_tensors="pt"
+    )
+    tokenized_inputs = tokenized_inputs.to(device)
+
+    # Perform model inference
+    with torch.no_grad():
+        outputs = model(**tokenized_inputs)
+
+    # Process the model outputs
+    logits = outputs.logits
+    softmax = torch.nn.functional.softmax(logits, dim=-1)
+    confidence_scores, predictions = torch.max(softmax, dim=2)
+    average_confidence = confidence_scores.mean().item()
+
+    # Convert model predictions to labels
+    predictions_labels = [
+        model_manager.model.config.id2label.get(pred.item(), "O")
+        for pred in predictions[0]
+    ]
+    detected_labels = [label for label in predictions_labels if label != "O"]
+
+    # Find the most frequent label among detected labels, if any
+    if detected_labels:
+        highest_avg_label = max(set(detected_labels), key=detected_labels.count)
+        highest_avg_conf = (
+            confidence_scores[0][
+                predictions[0] == model_manager.model.config.label2id[highest_avg_label]
+            ]
+            .mean()
+            .item()
+        )
+    else:
+        highest_avg_label = "None"  # Use 'None' if no entity detected
+        highest_avg_conf = 0.0
+
+    # Return the processed line information
+    return line, highest_avg_label, highest_avg_conf, average_confidence > 0.80
 
 
-def main():
-    # Check for new PyPI package version
-    check_new_pypi_version()
-
-    # Ensure the model folder exists and is updated
-    ensure_model_folder_exists()
+def main():
     parser = argparse.ArgumentParser(description="Entity recognition using BERT.")
     parser.add_argument(
         "-p", "--prompt", type=str, help="Direct text prompt for recognizing entities."
     )
@@ -357,62 +454,98 @@
         action="store_true",
         help="Enable GPU usage for model inference.",
     )
     parser.add_argument(
-        "-dir", "--model_directory",
-        type=str,
-        default=DEFAULT_MODEL_PATH,
-        help="Directory where the BERT model should be downloaded and unzipped."
-)
-
+        "-dir",
+        "--model_directory",
+        type=str,
+        default=DEFAULT_MODEL_PATH,
+        help="Directory where the BERT model should be downloaded and unzipped.",
+    )
+    parser.add_argument(
+        "--line_by_line",
+        action="store_true",
+        help="Process text line by line and yield results incrementally.",
+    )
     args = parser.parse_args()
 
+    # Early exit if only displaying help
+    if not any([args.prompt, args.file]):
+        parser.print_help()
+        return
+
+    # Now we ensure the model folder exists if needed
+    if args.prompt or args.file:
+        ensure_model_folder_exists(args.model_directory, auto_update=True)
+
     # Determine whether to use the GPU or not based on the user's command line input
     device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu"
 
     if args.prompt:
-        line, highest_avg_label, highest_avg_conf, high_conf = process_text(
-            args.prompt, args.model_path, device
-        )
-        # Print results with the highest average label and its confidence
-        print(
-            f"{line}: {highest_avg_label} (Highest Avg. Conf.: {sum(highest_avg_conf)/len(highest_avg_conf):.2f})"
-        )
+        if args.line_by_line:
+            # If line-by-line mode is enabled, iterate over generator
+            for (
+                processed_text,
+                highest_avg_label,
+                highest_avg_confidence,
+                is_high_confidence,
+            ) in process_text(args.prompt, args.model_path, device, True):
+                print(f"Processed Text: {processed_text}")
+                print(f"Highest Average Label: {highest_avg_label}")
+                print(f"Highest Average Confidence: {highest_avg_confidence}")
+                print(f"Is High Confidence: {is_high_confidence}")
+        else:
+            # Process the entire text as a single block
+            (
+                processed_text,
+                highest_avg_label,
+                highest_avg_confidence,
+                is_high_confidence,
+            ) = process_text(args.prompt, args.model_path, device, False)
+            print(f"Processed Text: {processed_text}")
+            print(f"Highest Average Label: {highest_avg_label}")
+            print(f"Highest Average Confidence: {highest_avg_confidence}")
+            print(f"Is High Confidence: {is_high_confidence}")
     elif args.file:
-        try:
-            results = []
-            with open(args.file, "r") as file:
-                file_content = file.read()
-                lines = (
-                    file_content.split(args.delimiter)
-                    if args.delimiter != "\n"
-                    else file_content.splitlines()
-                )
-                for line in lines:
-                    line = line.strip()
-                    if line:  # Avoid processing empty lines
-                        results.append(process_text(line, args.model_path, device))
-
-            with open(args.output, "w") as html_file:
-                html_file.write("\n")
-                for line, highest_avg_label, highest_avg_conf, high_conf in results:
-                    debug_info = ""
-                    if args.debug:
-                        debug_info = f" (Highest Avg. Label: {highest_avg_label}, Highest Avg. Conf.: {sum(highest_avg_conf)/len(highest_avg_conf):.2f})"
-
-                    colored_line = html.escape(line)
-                    if high_conf:
-                        colored_line = (
-                            f"{colored_line}"
-                        )
-                    html_file.write(f"{colored_line}{debug_info}\n")
-                html_file.write("")
-            print(f"Output written to {args.output}")
-        except FileNotFoundError:
-            print(f"The file {args.file} was not found.")
-    else:
-        print(
-            "No input provided. Please use --prompt or --file to provide input text for entity recognition."
+        # Adapt file processing as needed, similar to the prompt handling
+        process_file(
+            args.file,
+            args.model_path,
+            device,
+            args.output,
+            args.debug,
+            args.delimiter,
+            args.line_by_line,
         )
 
 
+def process_file(file_path, model_path, device, output_path, debug, delimiter):
+    try:
+        results = []
+        with open(file_path, "r") as file:
+            file_content = file.read()
+            lines = (
+                file_content.split(delimiter)
+                if delimiter != "\n"
+                else file_content.splitlines()
+            )
+            for line in lines:
+                line = line.strip()
+                if line:  # Avoid processing empty lines
+                    results.append(process_text(line, model_path, device))
+
+        with open(output_path, "w") as html_file:
+            html_file.write("\n")
+            for line, highest_avg_label, highest_avg_conf, high_conf in results:
+                debug_info = ""
+                if debug:
+                    debug_info = f" (Highest Avg. Label: {highest_avg_label}, Highest Avg. Conf.: {sum(highest_avg_conf)/len(highest_avg_conf):.2f})"
+                colored_line = html.escape(line)
+                if high_conf:
+                    colored_line = f"{colored_line}"
+                html_file.write(f"{colored_line}{debug_info}\n")
+            html_file.write("")
+        logging.info(f"Output written to {output_path}")
+    except FileNotFoundError:
+        logging.info(f"The file {file_path} was not found.")
+
+
 if __name__ == "__main__":
     main()

From 2f1ac3a62a4109b80155b1477a12ed8c5e098a3d Mon Sep 17 00:00:00 2001
From: David I
Date: Wed, 6 Mar 2024 16:19:27 -0500
Subject: [PATCH 04/10] ignoring warnings

---
 src/eclipse/eclipse.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/eclipse/eclipse.py b/src/eclipse/eclipse.py
index 05e5285..c8411d4 100644
--- a/src/eclipse/eclipse.py
+++ b/src/eclipse/eclipse.py
@@ -20,11 +20,10 @@
 
 # Suppress specific warning from transformers
 warnings.filterwarnings(
-    "ignore",
+    action="ignore",
     message="Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.",
+    category=UserWarning
 )
-
-
 # Configure basic logging
 # This will set the log level to ERROR, meaning only error and critical messages will be logged
 # You can specify a filename to write the logs to a file; otherwise, it will log to stderr

From ec1117be0181e7dc073377c19c96c677a7d486ea Mon Sep 17 00:00:00 2001
From: dige-mothership
Date: Wed, 6 Mar 2024 16:27:49 -0500
Subject: [PATCH 05/10] formatting and removing unneeded error suppression

---
 .DS_Store              | Bin 6148 -> 6148 bytes
 src/eclipse/eclipse.py |   9 +--------
 2 files changed, 1 insertion(+), 8 deletions(-)

diff --git a/.DS_Store b/.DS_Store
index a56fb214ee9bd0c6482cf1ccc77ae52f70b1fab8..7752e80edca84d7509e3ae887e3829a9ac336c5f 100644
GIT binary patch
delta 129
zcmZoMXffEJ#u9hQmVtqRg+Y%YogtH(Y+N*I0v?rPv%cYqQ-H1OWW5BbopJ

diff --git a/src/eclipse/eclipse.py b/src/eclipse/eclipse.py
index c8411d4..180f346 100644
--- a/src/eclipse/eclipse.py
+++ b/src/eclipse/eclipse.py
@@ -6,7 +6,6 @@
 import shutil
 import socket
 import subprocess
-import warnings
 from importlib.metadata import version
 from typing import List, Set, Tuple
 from zipfile import ZipFile
@@ -18,12 +17,6 @@
 from prompt_toolkit.styles import Style
 from transformers import BertForTokenClassification, BertTokenizerFast
 
-# Suppress specific warning from transformers
-warnings.filterwarnings(
-    action="ignore",
-    message="Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.",
-    category=UserWarning
-)
 # Configure basic logging
 # This will set the log level to ERROR, meaning only error and critical messages will be logged
 # You can specify a filename to write the logs to a file; otherwise, it will log to stderr
@@ -74,7 +67,7 @@ def __init__(self, model_path, device):
         self.device = torch.device(
             "cuda" if torch.cuda.is_available() and device == "cuda" else "cpu"
         )
-        self.tokenizer = BertTokenizerFast.from_pretrained(model_path)
+
         self.model = BertForTokenClassification.from_pretrained(model_path)
         self.model.config.id2label = id_to_label
         self.model.config.label2id = label_to_id

From a2a1403f27a2adaf559978d58f4f53e79f668876 Mon Sep 17 00:00:00 2001
From: dige-mothership
Date: Wed, 6 Mar 2024 16:46:53 -0500
Subject: [PATCH 06/10] fixing file output

---
 .DS_Store              | Bin 6148 -> 8196 bytes
 src/eclipse/eclipse.py |  13 ++++++++-----
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/.DS_Store b/.DS_Store
index 7752e80edca84d7509e3ae887e3829a9ac336c5f..e13aa0ed12c261275650f92ddbabcd0687f9b123 100644
GIT binary patch
literal 8196
zcmeHMOKTHR6h60&>BQ2Sg+i$a18!WDHff|NZqhUr3hJUG3O-0PNlaqr5%Oq|P{=|z
z?aDtOZd{5k-AhH$wHrZyhb#T=BkfGmsXHm;o^bCsx#yg@=X_@}!
z?6L~`5~Vva)qF6uGgBQ36T9R5rZ6X#m8fUUfM#Hn0p7dEXpyQ^r*^1+U%9vz)VY2>
z#%w?HDmi`>h1owgVnIUe4@Tl3%nc1ZNZhM2K!QtI%J1-r`aE)7Qq-EEXQB$$ry&t
zlYw?c|?o4smy2!B#&(@IbyOV4~
ze8k^u4xcN?!k6cVRSNAFtWlIrDC78h9`jkcMH{q6>vV@K%!|mN9FCpwD>|6xEaw8F
z31*%m*Alhi*Ty%OoFU)+A>RaTVwY9Q)9r}eeh+yVs6hp1?_kGcsxVOBromQ^ce8|7u^?k=4
z92f0-^-8ypJiq39?PA@lG=)RGT!G1>`*pAEcXNKnE4PHNB^5@(NEDKDyv7R?~BQk5^z29oK5n42&lOlQQi&-v6&hzyBXk
z&*?3fiYo##qw4@hh(E3i@eu1F}5(caJ#-l;etuD<3Q1l0}uZ&#IXre?vqv6mxvX}
Re}566pZ|h)J&iX5e*u>KL=^x4

delta 305
zcmZp1XfcprU|?W$DortDU=RQ@Ie-{Mvv5sJ6q~50$SAlmU^g?P;A9>F$#}g>wje16
zJ%)6KOoo!&d>5Cboctt^Hdoz~$@5MhbwrgoVHlj8pIZRb!@yuZfnlNKkZ=Y0dSl{u=E?jro}e&b
RV1k4PD2Ny~$Mei#1^{o?MOOd-

diff --git a/src/eclipse/eclipse.py b/src/eclipse/eclipse.py
index 180f346..2e7a665 100644
--- a/src/eclipse/eclipse.py
+++ b/src/eclipse/eclipse.py
@@ -497,6 +497,7 @@ def main():
             print(f"Is High Confidence: {is_high_confidence}")
     elif args.file:
         # Adapt file processing as needed, similar to the prompt handling
+
         process_file(
             args.file,
             args.model_path,
             device,
             args.output,
             args.debug,
             args.delimiter,
-            args.line_by_line,
         )
 
 
@@ -521,14 +521,19 @@ def process_file(file_path, model_path, device, output_path, debug, delimiter):
             for line in lines:
                 line = line.strip()
                 if line:  # Avoid processing empty lines
+                    # Assume process_text is a function that processes the text and returns a tuple
+                    # containing the processed line, highest average label, highest average confidence,
+                    # and a boolean indicating if the confidence is high.
                     results.append(process_text(line, model_path, device))
 
         with open(output_path, "w") as html_file:
-            html_file.write("\n")
-            for line, highest_avg_label, highest_avg_conf, high_conf in results:
+            html_file.write("Processed Output\n")
+            for result in results:
+                line, highest_avg_label, highest_avg_conf, high_conf = result
                 debug_info = ""
                 if debug:
                     debug_info = f" (Highest Avg. Label: {highest_avg_label}, Highest Avg. Conf.: {sum(highest_avg_conf)/len(highest_avg_conf):.2f})"
+                # Escape the line to convert any HTML special characters to their equivalent entities
                 colored_line = html.escape(line)
                 if high_conf:
                     colored_line = f"{colored_line}"
                 html_file.write(f"{colored_line}{debug_info}\n")
             html_file.write("")
         logging.info(f"Output written to {output_path}")
     except FileNotFoundError:
         logging.info(f"The file {file_path} was not found.")
-
-
 if __name__ == "__main__":
     main()

From 0b77df98bf14e0ade7c9e3b6609182341f34d936 Mon Sep 17 00:00:00 2001
From: David I
Date: Wed, 6 Mar 2024 17:23:48 -0500
Subject: [PATCH 07/10] updating readme and logging

---
 README.md              | 51 +++++++++++++++++++++---------------------
 src/eclipse/eclipse.py |  8 ++++++-
 2 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/README.md b/README.md
index d677210..14c7608 100644
--- a/README.md
+++ b/README.md
@@ -114,8 +114,7 @@ pip install eclipse-ai --upgrade
 ## Usage.
 
 ``` bash
-usage: eclipse [-h] [-p PROMPT] [-f FILE] [-m MODEL_PATH] [-o OUTPUT] [--debug] [-d DELIMITER] [-g]
-               [-dir MODEL_DIRECTORY]
+usage: eclipse [-h] [-p PROMPT] [-f FILE] [-m MODEL_PATH] [-o OUTPUT] [--debug] [-d DELIMITER] [-g] [-dir MODEL_DIRECTORY] [--line_by_line]
 
 Entity recognition using BERT.
 
 options:
   -h, --help            show this help message and exit
   -p PROMPT, --prompt PROMPT
                         Direct text prompt for recognizing entities.
   -f FILE, --file FILE  Path to a text file containing text inputs.
   -m MODEL_PATH, --model_path MODEL_PATH
                         Path to the pretrained BERT model directory.
   -o OUTPUT, --output OUTPUT
                         Path to the output HTML file.
   --debug               Enable debug mode to display prediction details.
   -d DELIMITER, --delimiter DELIMITER
                         Delimiter to separate text inputs, defaults to newline.
   -g, --use_gpu         Enable GPU usage for model inference.
   -dir MODEL_DIRECTORY, --model_directory MODEL_DIRECTORY
                         Directory where the BERT model should be downloaded and unzipped.
+  --line_by_line        Process text line by line and yield results incrementally.
 ```
 
 Here are some examples:
@@ -153,53 +153,54 @@ Additional Options
 ## Usage as a module
 
 ```python
-from eclipse import process_text # Make sure this is the correct import
+# Correct import based on your project structure
+from eclipse import process_text
 
 model_path = "./ner_model_bert"
 input_text = "Your example text here."
 
-line_by_line = True # Change this to False if you want to process the whole text at once
+# Set this to True if you want to process the text line by line, or False to process all at once
+line_by_line = False
 
 try:
     # Handle both line-by-line processing and whole text processing
     if line_by_line:
         # Process the text line by line
-        for result in process_text(input_text, model_path, 'cpu', line_by_line=True):
+        for result in process_text(input_text, model_path, "cpu", line_by_line=True):
+            # In line-by-line mode, result should not be None, but check to be safe
             if result:
-                processed_text, highest_avg_label, highest_avg_confidence, is_high_confidence = result
+                (
+                    processed_text,
+                    highest_avg_label,
+                    highest_avg_confidence,
+                    is_high_confidence,
+                ) = result
                 print(f"Processed Text: {processed_text}")
                 print(f"Highest Average Label: {highest_avg_label}")
                 print(f"Highest Average Confidence: {highest_avg_confidence}")
                 print(f"Is High Confidence: {is_high_confidence}")
             else:
-                # If result is empty (which should not happen in line-by-line mode), assign default values
-                print("Processed Text: Error in processing")
-                print("Highest Average Label: Error")
-                print("Highest Average Confidence: Error")
-                print("Is High Confidence: Error")
+                print("Error: Empty result for a line.")
     else:
-        # If line_by_line is set to False, expecting a single result
-        result = process_text(input_text, model_path, 'cpu', line_by_line=False)
-        if result:  # Checking if result is not empty
-            processed_text, highest_avg_label, highest_avg_confidence, is_high_confidence = result
+        # Process the entire text as a single block
+        result = process_text(input_text, model_path, "cpu", line_by_line=False)
+        if result:
+            (
+                processed_text,
+                highest_avg_label,
+                highest_avg_confidence,
+                is_high_confidence,
+            ) = result
             print(f"Processed Text: {processed_text}")
             print(f"Highest Average Label: {highest_avg_label}")
             print(f"Highest Average Confidence: {highest_avg_confidence}")
             print(f"Is High Confidence: {is_high_confidence}")
         else:
-            # If result is empty, assign default values
-            print("Processed Text: Error in processing")
-            print("Highest Average Label: Error")
-            print("Highest Average Confidence: Error")
-            print("Is High Confidence: Error")
+            print("Error: Empty result for the text.")
 
 except Exception as e:  # Catching general exceptions
     print(f"Error processing text: {e}")
-    # Default error handling values
-    print("Processed Text: Error in processing")
-    print("Highest Average Label: Error")
-    print("Highest Average Confidence: Error")
-    print("Is High Confidence: Error")
 ```
 
 ## Understanding the Output
diff --git a/src/eclipse/eclipse.py b/src/eclipse/eclipse.py
index 2e7a665..3400451 100644
--- a/src/eclipse/eclipse.py
+++ b/src/eclipse/eclipse.py
@@ -11,12 +11,14 @@
 from zipfile import ZipFile
 
 import requests
+import transformers
 import torch
 from prompt_toolkit import prompt
 from prompt_toolkit.history import InMemoryHistory
 from prompt_toolkit.styles import Style
 from transformers import BertForTokenClassification, BertTokenizerFast
 
+transformers.logging.set_verbosity_error()
 # Configure basic logging
 # This will set the log level to ERROR, meaning only error and critical messages will be logged
 # You can specify a filename to write the logs to a file; otherwise, it will log to stderr
@@ -73,6 +75,8 @@ def __init__(self, model_path, device):
         self.model.config.label2id = label_to_id
         self.model.to(self.device)
         self.model.eval()
+        self.tokenizer = BertTokenizerFast.from_pretrained(model_path)
+
 
     @staticmethod
     def get_instance(model_path=DEFAULT_MODEL_PATH, device="cpu"):
@@ -485,6 +489,7 @@ def main():
             print(f"Is High Confidence: {is_high_confidence}")
         else:
             # Process the entire text as a single block
+
             (
                 processed_text,
                 highest_avg_label,
@@ -532,7 +537,8 @@ def process_file(file_path, model_path, device, output_path, debug, delimiter):
                 line, highest_avg_label, highest_avg_conf, high_conf = result
                 debug_info = ""
                 if debug:
-                    debug_info = f" (Highest Avg. Label: {highest_avg_label}, Highest Avg. Conf.: {sum(highest_avg_conf)/len(highest_avg_conf):.2f})"
+                    debug_info = f" (Highest Avg. Label: {highest_avg_label}, Highest Avg. Conf.: {highest_avg_conf:.2f})"
+
                 # Escape the line to convert any HTML special characters to their equivalent entities
                 colored_line = html.escape(line)
                 if high_conf:

From 2ee816bff4441216a8fc25fc52d811b9f5fd8685 Mon Sep 17 00:00:00 2001
From: David I
Date: Wed, 6 Mar 2024 17:33:58 -0500
Subject: [PATCH 08/10] updating init

---
 src/eclipse/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/eclipse/__init__.py b/src/eclipse/__init__.py
index e69de29..b0de1a4 100644
--- a/src/eclipse/__init__.py
+++ b/src/eclipse/__init__.py
@@ -0,0 +1 @@
+from .eclipse import process_text
\ No newline at end of file

From d11cd8d3538667d10dc454cc26abb8b0bf72ba22 Mon Sep 17 00:00:00 2001
From: David I
Date: Wed, 6 Mar 2024 18:54:06 -0500
Subject: [PATCH 09/10] formatting

---
 src/eclipse/__init__.py |  2 +-
 src/eclipse/eclipse.py  | 42 +++++++++++++++++++++++++++++++++--------
 2 files changed, 35 insertions(+), 9 deletions(-)

diff --git a/src/eclipse/__init__.py b/src/eclipse/__init__.py
index b0de1a4..c5076b4 100644
--- a/src/eclipse/__init__.py
+++ b/src/eclipse/__init__.py
@@ -1 +1 @@
-from .eclipse import process_text
\ No newline at end of file
+from .eclipse import process_text
diff --git a/src/eclipse/eclipse.py b/src/eclipse/eclipse.py
index 3400451..e227b8f 100644
--- a/src/eclipse/eclipse.py
+++ b/src/eclipse/eclipse.py
@@ -11,8 +11,8 @@
 from zipfile import ZipFile
 
 import requests
-import transformers
 import torch
+import transformers
 from prompt_toolkit import prompt
 from prompt_toolkit.history import InMemoryHistory
 from prompt_toolkit.styles import Style
@@ -77,7 +77,6 @@ def __init__(self, model_path, device):
         self.model.eval()
         self.tokenizer = BertTokenizerFast.from_pretrained(model_path)
 
-
     @staticmethod
     def get_instance(model_path=DEFAULT_MODEL_PATH, device="cpu"):
         if ModelManager.instance is None:
@@ -327,7 +326,11 @@
 
 def process_text(
-    input_text: str, model_path: str, device: str, line_by_line: bool = False
+    input_text: str,
+    model_path: str,
+    device: str,
+    line_by_line: bool = False,
+    confidence_threshold: float = 0.80,
 ):
     # Ensure the model folder exists and is up to date
     ensure_model_folder_exists(model_path)
@@ -362,7 +365,11 @@
         return process_single(input_text)  # This returns a single tuple
 
 
-def process_single_line(line: str, model_manager, device):
+def process_single_line(
+    line: str, model_manager, device, confidence_threshold: float = 0.80
+):
+    # Rest of the function remains unchanged
+
     # Access the tokenizer and model from the model manager
     tokenizer = model_manager.tokenizer
model = model_manager.model @@ -406,7 +413,12 @@ def process_single_line(line: str, model_manager, device): highest_avg_conf = 0.0 # Return the processed line information - return line, highest_avg_label, highest_avg_conf, average_confidence > 0.80 + return ( + line, + highest_avg_label, + highest_avg_conf, + average_confidence > confidence_threshold, + ) def main(): @@ -461,6 +473,14 @@ def main(): action="store_true", help="Process text line by line and yield results incrementally.", ) + parser.add_argument( + "-c", + "--confidence_threshold", + type=float, + default=0.90, + help="Confidence threshold for considering predictions as high confidence.", + ) + args = parser.parse_args() # Early exit if only displaying help @@ -482,14 +502,16 @@ def main(): highest_avg_label, highest_avg_confidence, is_high_confidence, - ) in process_text(args.prompt, args.model_path, device, True): + ) in process_text( + args.prompt, args.model_path, device, True, args.confidence_threshold + ): print(f"Processed Text: {processed_text}") print(f"Highest Average Label: {highest_avg_label}") print(f"Highest Average Confidence: {highest_avg_confidence}") print(f"Is High Confidence: {is_high_confidence}") else: # Process the entire text as a single block - + ( processed_text, highest_avg_label, @@ -532,7 +554,9 @@ def process_file(file_path, model_path, device, output_path, debug, delimiter): results.append(process_text(line, model_path, device)) with open(output_path, "w") as html_file: - html_file.write("Processed Output\n") + html_file.write( + "Processed Output\n" + ) for result in results: line, highest_avg_label, highest_avg_conf, high_conf = result debug_info = "" @@ -548,5 +572,7 @@ def process_file(file_path, model_path, device, output_path, debug, delimiter): logging.info(f"Output written to {output_path}") except FileNotFoundError: logging.info(f"The file {file_path} was not found.") + + if __name__ == "__main__": main() From f9ae04ea2123251ff92498fb766ce0641a897427 Mon Sep 17 00:00:00 2001 From: David I Date: Wed, 6 Mar 2024 18:56:20 -0500 Subject: [PATCH 10/10] updating ruff --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5efdb99..8bc20f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [tool.ruff] -ignore = ["E501","E722"] +ignore = ["E501","E722","F401"] fixable = ["ALL"] \ No newline at end of file
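
---

A note on the module API that results from this series. After PATCH 03 and PATCH 09, `process_text` takes `line_by_line` and `confidence_threshold` keyword arguments, returns a single 4-tuple in whole-text mode, returns a generator of 4-tuples in line-by-line mode, and reuses one cached model through `ModelManager`. The sketch below is illustrative only: it assumes a model folder is already present at `./ner_model_bert`, and the sample text is made up.

```python
# Minimal sketch of calling the post-series API (PATCH 03 + PATCH 09).
# Assumes ./ner_model_bert exists locally; the input text is illustrative.
from eclipse import process_text

model_path = "./ner_model_bert"
text = "first line of input\nsecond line of input"

# Whole-text mode: a single (text, label, confidence, is_high_confidence) tuple.
line, label, conf, high_conf = process_text(
    text, model_path, "cpu", line_by_line=False
)
print(f"{label} ({conf:.2f}) high confidence: {high_conf}")

# Line-by-line mode: a generator yielding one tuple per input line. The
# underlying BERT model is loaded only once, since ModelManager caches it.
for line, label, conf, high_conf in process_text(
    text, model_path, "cpu", line_by_line=True
):
    print(f"{line!r}: {label} ({conf:.2f})")
```

Note that `process_text` also calls `ensure_model_folder_exists` internally, so the first invocation may download the model archive from S3 if the local copy is missing or stale.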
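One design property of the `ModelManager` singleton from PATCH 03 is easy to miss: `get_instance` only checks whether an instance already exists, never which `model_path` or `device` it was built with, so a later call with different arguments silently returns the first model. A reduced, dependency-free sketch of the pattern follows; the `_FakeModel` class is a hypothetical stand-in for the expensive BERT load, not the shipped code.

```python
# Reduced illustration of the load-once singleton from PATCH 03.
class ModelManager:
    instance = None

    class _FakeModel:
        def __init__(self, model_path, device):
            # Stands in for BertForTokenClassification.from_pretrained(...),
            # the expensive step the singleton exists to run exactly once.
            print(f"loading model from {model_path} on {device}")
            self.model_path = model_path

    @staticmethod
    def get_instance(model_path="./ner_model_bert", device="cpu"):
        # Arguments are ignored once an instance exists.
        if ModelManager.instance is None:
            ModelManager.instance = ModelManager._FakeModel(model_path, device)
        return ModelManager.instance


a = ModelManager.get_instance("./ner_model_bert", "cpu")
b = ModelManager.get_instance("./other_model", "cuda")  # no reload happens
assert a is b and b.model_path == "./ner_model_bert"
```

This is a reasonable trade for a CLI that processes one model per run, but callers embedding the module in a long-lived process should be aware that switching models requires resetting `ModelManager.instance`.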