From 6c0b2b27ade52c678e30ec5f8ae5355546d2e1b5 Mon Sep 17 00:00:00 2001 From: Kye Date: Fri, 15 Sep 2023 18:15:21 -0400 Subject: [PATCH] documentation for `CodeDatasetBuilder` --- {pytorch => old}/download.py | 0 {pytorch => old}/process.py | 0 pyproject.toml | 2 +- pytorch/dataset_builder.py | 130 ++++++++++++++++++++++++++++------- pytorch/downloader.py | 23 ++++--- 5 files changed, 122 insertions(+), 33 deletions(-) rename {pytorch => old}/download.py (100%) rename {pytorch => old}/process.py (100%) diff --git a/pytorch/download.py b/old/download.py similarity index 100% rename from pytorch/download.py rename to old/download.py diff --git a/pytorch/process.py b/old/process.py similarity index 100% rename from pytorch/process.py rename to old/process.py diff --git a/pyproject.toml b/pyproject.toml index f66f1b3..6773740 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "pytorch-dataset" -version = "0.0.3" +version = "0.0.4" description = "Pytorch Dataset - Pytorch" license = "MIT" authors = ["Kye Gomez "] diff --git a/pytorch/dataset_builder.py b/pytorch/dataset_builder.py index 2b482e2..6aa3d4c 100644 --- a/pytorch/dataset_builder.py +++ b/pytorch/dataset_builder.py @@ -1,46 +1,128 @@ import os -import re from datasets import Dataset, load_from_disk class CodeDatasetBuilder: + """ + A utility class to build and manage code datasets. + + Args: + root_dir (str): The root directory to search for code files. + + Attributes: + root_dir (str): The root directory to search for code files. + + Example: + code_builder = CodeDatasetBuilder("lucidrains_repositories") + code_builder.save_dataset("lucidrains_python_code_dataset", exclude_files=["setup.py"], exclude_dirs=["tests"]) + code_builder.push_to_hub("lucidrains_python_code_dataset", organization="kye") + """ + def __init__(self, root_dir): self.root_dir = root_dir - def collect_code_files(self, file_extension=".py", exclude_files=None): + def collect_code_files(self, file_extension=".py", exclude_files=None, exclude_dirs=None): + """ + Collects code snippets from files in the specified directory and its subdirectories. + + Args: + file_extension (str, optional): The file extension of code files to include (default is ".py"). + exclude_files (list of str, optional): List of file names to exclude from collection (default is None). + exclude_dirs (list of str, optional): List of directory names to exclude from collection (default is None). + + Returns: + list of str: List of code snippets. + """ code_snippets = [] exclude_files = set(exclude_files) if exclude_files else set() - - for root, _, files in os.walk(self.root_dir): - for file in files: - if file.endswith(file_extension) and file not in exclude_files: - with open(os.path.join(root, file), "r", encoding="utf-8") as f: - code = f.read() - # Remove import statements to exclude dependencies - code = re.sub(r'^import .*\n', '', code, flags=re.MULTILINE) - code = re.sub(r'^from .*\n', '', code, flags=re.MULTILINE) - code_snippets.append(code) - + exclude_dirs = set(exclude_dirs) if exclude_dirs else set() + + try: + for root, dirs, files in os.walk(self.root_dir): + # Exclude specified directories + dirs[:] = [d for d in dirs if d not in exclude_dirs] + + for file in files: + if file.endswith(file_extension) and file not in exclude_files: + with open(os.path.join(root, file), "r", encoding="utf-8") as f: + code = f.read() + code_snippets.append(code) + except Exception as e: + raise RuntimeError(f"Error while collecting code files: {e}") + return code_snippets - def create_dataset(self, file_extension=".py", exclude_files=None): - code_snippets = self.collect_code_files(file_extension, exclude_files) + def create_dataset(self, file_extension=".py", exclude_files=None, exclude_dirs=None): + """ + Creates a dataset from collected code snippets. + + Args: + file_extension (str, optional): The file extension of code files to include (default is ".py"). + exclude_files (list of str, optional): List of file names to exclude from collection (default is None). + exclude_dirs (list of str, optional): List of directory names to exclude from collection (default is None). + + Returns: + datasets.Dataset: The code dataset. + """ + code_snippets = self.collect_code_files(file_extension, exclude_files, exclude_dirs) dataset = Dataset.from_dict({"code": code_snippets}) return dataset - def save_dataset(self, dataset_name): - dataset = self.create_dataset() - dataset.save_to_disk(dataset_name) + def save_dataset(self, dataset_name, file_extension=".py", exclude_files=None, exclude_dirs=None): + """ + Saves the code dataset to disk. + + Args: + dataset_name (str): The name for the saved dataset. + file_extension (str, optional): The file extension of code files to include (default is ".py"). + exclude_files (list of str, optional): List of file names to exclude from collection (default is None). + exclude_dirs (list of str, optional): List of directory names to exclude from collection (default is None). + + Raises: + RuntimeError: If there is an error while saving the dataset. + """ + dataset = self.create_dataset(file_extension, exclude_files, exclude_dirs) + try: + dataset.save_to_disk(dataset_name) + except Exception as e: + raise RuntimeError(f"Error while saving the dataset: {e}") def load_dataset(self, dataset_name): - loaded_dataset = load_from_disk(dataset_name) - return loaded_dataset + """ + Loads a code dataset from disk. + + Args: + dataset_name (str): The name of the saved dataset. + + Returns: + datasets.Dataset: The loaded code dataset. + """ + try: + loaded_dataset = load_from_disk(dataset_name) + return loaded_dataset + except Exception as e: + raise RuntimeError(f"Error while loading the dataset: {e}") def push_to_hub(self, dataset_name, organization=None): + """ + Pushes the code dataset to the Hugging Face Model Hub. + + Args: + dataset_name (str): The name of the saved dataset. + organization (str, optional): The organization on the Model Hub to push to (default is "username"). + """ dataset = self.load_dataset(dataset_name) organization = organization or "username" - dataset.push_to_hub(f"{organization}/{dataset_name}") + try: + dataset.push_to_hub(f"{organization}/{dataset_name}") + except Exception as e: + raise RuntimeError(f"Error while pushing the dataset to the Hugging Face Model Hub: {e}") # Example usage: -# code_builder = CodeDatasetBuilder("lucidrains_repositories") -# code_builder.save_dataset("lucidrains_python_code_dataset") -# code_builder.push_to_hub("lucidrains_python_code_dataset", organization="kye") +code_builder = CodeDatasetBuilder("lucidrains_repositories") + +code_builder.save_dataset( + "lucidrains_python_code_dataset", + exclude_files=["setup.py"], exclude_dirs=["tests"] +) + +code_builder.push_to_hub("lucidrains_python_code_dataset", organization="kye") diff --git a/pytorch/downloader.py b/pytorch/downloader.py index 5793414..2e5b93d 100644 --- a/pytorch/downloader.py +++ b/pytorch/downloader.py @@ -3,12 +3,6 @@ import zipfile class GitHubRepoDownloader: - """ - # Example usage: - # downloader = GitHubRepoDownloader(username="lucidrains", download_dir="lucidrains_repositories") - # downloader.download_repositories() - - """ def __init__(self, username, download_dir): self.username = username self.api_url = f"https://api.github.com/users/{username}/repos" @@ -28,15 +22,28 @@ def download_repositories(self): if zip_response.status_code == 200: with open(zip_file_path, "wb") as zip_file: zip_file.write(zip_response.content) - self._unzip_repository(zip_file_path) - print(f"Downloaded and unzipped {repo_name}") + if self._is_valid_zip(zip_file_path): + self._unzip_repository(zip_file_path) + print(f"Downloaded and unzipped {repo_name}") + else: + print(f"Invalid ZIP file for {repo_name}") else: print(f"Failed to download {repo_name}") else: print(f"Failed to fetch repositories for user {self.username}") print("All repositories downloaded and unzipped.") + def _is_valid_zip(self, zip_file_path): + try: + with zipfile.ZipFile(zip_file_path, "r") as zip_ref: + return True + except zipfile.BadZipFile: + return False + def _unzip_repository(self, zip_file_path): with zipfile.ZipFile(zip_file_path, "r") as zip_ref: zip_ref.extractall(self.download_dir) +# Example usage: +downloader = GitHubRepoDownloader(username="lucidrains", download_dir="lucidrains_repositories") +downloader.download_repositories()