documentation for CodeDatasetBuilder
Kye committed Sep 15, 2023
1 parent 7b536bf commit 6c0b2b2
Showing 5 changed files with 122 additions and 33 deletions.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "pytorch-dataset"
version = "0.0.3"
version = "0.0.4"
description = "Pytorch Dataset - Pytorch"
license = "MIT"
authors = ["Kye Gomez <[email protected]>"]
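The only change to pyproject.toml is the version bump from 0.0.3 to 0.0.4. A minimal sketch for checking which version is actually installed, assuming the project has been installed (for example with "pip install .") under the distribution name "pytorch-dataset" declared above:

# Hypothetical check of the installed distribution version (Python 3.8+).
from importlib.metadata import version

print(version("pytorch-dataset"))  # expected to print "0.0.4" after this commit
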
130 changes: 106 additions & 24 deletions pytorch/dataset_builder.py
@@ -1,46 +1,128 @@
import os
import re
from datasets import Dataset, load_from_disk

class CodeDatasetBuilder:
"""
A utility class to build and manage code datasets.
Args:
root_dir (str): The root directory to search for code files.
Attributes:
root_dir (str): The root directory to search for code files.
Example:
code_builder = CodeDatasetBuilder("lucidrains_repositories")
code_builder.save_dataset("lucidrains_python_code_dataset", exclude_files=["setup.py"], exclude_dirs=["tests"])
code_builder.push_to_hub("lucidrains_python_code_dataset", organization="kye")
"""

def __init__(self, root_dir):
self.root_dir = root_dir

def collect_code_files(self, file_extension=".py", exclude_files=None):
def collect_code_files(self, file_extension=".py", exclude_files=None, exclude_dirs=None):
"""
Collects code snippets from files in the specified directory and its subdirectories.
Args:
file_extension (str, optional): The file extension of code files to include (default is ".py").
exclude_files (list of str, optional): List of file names to exclude from collection (default is None).
exclude_dirs (list of str, optional): List of directory names to exclude from collection (default is None).
Returns:
list of str: List of code snippets.
"""
code_snippets = []
exclude_files = set(exclude_files) if exclude_files else set()

for root, _, files in os.walk(self.root_dir):
for file in files:
if file.endswith(file_extension) and file not in exclude_files:
with open(os.path.join(root, file), "r", encoding="utf-8") as f:
code = f.read()
# Remove import statements to exclude dependencies
code = re.sub(r'^import .*\n', '', code, flags=re.MULTILINE)
code = re.sub(r'^from .*\n', '', code, flags=re.MULTILINE)
code_snippets.append(code)

exclude_dirs = set(exclude_dirs) if exclude_dirs else set()

try:
for root, dirs, files in os.walk(self.root_dir):
# Exclude specified directories
dirs[:] = [d for d in dirs if d not in exclude_dirs]

for file in files:
if file.endswith(file_extension) and file not in exclude_files:
with open(os.path.join(root, file), "r", encoding="utf-8") as f:
code = f.read()
code_snippets.append(code)
except Exception as e:
raise RuntimeError(f"Error while collecting code files: {e}")

return code_snippets

def create_dataset(self, file_extension=".py", exclude_files=None):
code_snippets = self.collect_code_files(file_extension, exclude_files)
def create_dataset(self, file_extension=".py", exclude_files=None, exclude_dirs=None):
"""
Creates a dataset from collected code snippets.
Args:
file_extension (str, optional): The file extension of code files to include (default is ".py").
exclude_files (list of str, optional): List of file names to exclude from collection (default is None).
exclude_dirs (list of str, optional): List of directory names to exclude from collection (default is None).
Returns:
datasets.Dataset: The code dataset.
"""
code_snippets = self.collect_code_files(file_extension, exclude_files, exclude_dirs)
dataset = Dataset.from_dict({"code": code_snippets})
return dataset

def save_dataset(self, dataset_name):
dataset = self.create_dataset()
dataset.save_to_disk(dataset_name)
def save_dataset(self, dataset_name, file_extension=".py", exclude_files=None, exclude_dirs=None):
"""
Saves the code dataset to disk.
Args:
dataset_name (str): The name for the saved dataset.
file_extension (str, optional): The file extension of code files to include (default is ".py").
exclude_files (list of str, optional): List of file names to exclude from collection (default is None).
exclude_dirs (list of str, optional): List of directory names to exclude from collection (default is None).
Raises:
RuntimeError: If there is an error while saving the dataset.
"""
dataset = self.create_dataset(file_extension, exclude_files, exclude_dirs)
try:
dataset.save_to_disk(dataset_name)
except Exception as e:
raise RuntimeError(f"Error while saving the dataset: {e}")

def load_dataset(self, dataset_name):
loaded_dataset = load_from_disk(dataset_name)
return loaded_dataset
"""
Loads a code dataset from disk.
Args:
dataset_name (str): The name of the saved dataset.
Returns:
datasets.Dataset: The loaded code dataset.
"""
try:
loaded_dataset = load_from_disk(dataset_name)
return loaded_dataset
except Exception as e:
raise RuntimeError(f"Error while loading the dataset: {e}")

def push_to_hub(self, dataset_name, organization=None):
"""
Pushes the code dataset to the Hugging Face Hub.
Args:
dataset_name (str): The name of the saved dataset.
organization (str, optional): The Hub organization or username to push to (default is "username").
Raises:
RuntimeError: If there is an error while pushing the dataset.
"""
dataset = self.load_dataset(dataset_name)
organization = organization or "username"
dataset.push_to_hub(f"{organization}/{dataset_name}")
try:
dataset.push_to_hub(f"{organization}/{dataset_name}")
except Exception as e:
raise RuntimeError(f"Error while pushing the dataset to the Hugging Face Model Hub: {e}")

# Example usage:
# code_builder = CodeDatasetBuilder("lucidrains_repositories")
# code_builder.save_dataset("lucidrains_python_code_dataset")
# code_builder.push_to_hub("lucidrains_python_code_dataset", organization="kye")
code_builder = CodeDatasetBuilder("lucidrains_repositories")

code_builder.save_dataset(
"lucidrains_python_code_dataset",
exclude_files=["setup.py"], exclude_dirs=["tests"]
)

code_builder.push_to_hub("lucidrains_python_code_dataset", organization="kye")
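The module-level example above collects the code, saves the dataset, and pushes it to the Hub in one pass. As a minimal sketch of the round trip with the new load_dataset method, assuming the dataset was already saved to disk under the same name and that the module is importable as pytorch.dataset_builder:

from pytorch.dataset_builder import CodeDatasetBuilder

# Reload the dataset that save_dataset wrote to disk; the root directory is
# only used when collecting files, not when loading an existing dataset.
builder = CodeDatasetBuilder("lucidrains_repositories")
dataset = builder.load_dataset("lucidrains_python_code_dataset")

print(dataset)                    # Dataset({features: ['code'], num_rows: ...})
print(dataset[0]["code"][:200])   # first 200 characters of the first snippet
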
23 changes: 15 additions & 8 deletions pytorch/downloader.py
@@ -3,12 +3,6 @@
import zipfile

class GitHubRepoDownloader:
"""
# Example usage:
# downloader = GitHubRepoDownloader(username="lucidrains", download_dir="lucidrains_repositories")
# downloader.download_repositories()
"""
def __init__(self, username, download_dir):
self.username = username
self.api_url = f"https://api.github.com/users/{username}/repos"
@@ -28,15 +22,28 @@ def download_repositories(self):
if zip_response.status_code == 200:
with open(zip_file_path, "wb") as zip_file:
zip_file.write(zip_response.content)
self._unzip_repository(zip_file_path)
print(f"Downloaded and unzipped {repo_name}")
if self._is_valid_zip(zip_file_path):
self._unzip_repository(zip_file_path)
print(f"Downloaded and unzipped {repo_name}")
else:
print(f"Invalid ZIP file for {repo_name}")
else:
print(f"Failed to download {repo_name}")
else:
print(f"Failed to fetch repositories for user {self.username}")
print("All repositories downloaded and unzipped.")

def _is_valid_zip(self, zip_file_path):
try:
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
return True
except zipfile.BadZipFile:
return False

def _unzip_repository(self, zip_file_path):
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
zip_ref.extractall(self.download_dir)

# Example usage:
downloader = GitHubRepoDownloader(username="lucidrains", download_dir="lucidrains_repositories")
downloader.download_repositories()
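
Taken together, the two modules form a small pipeline: download a user's repositories, then build and save a dataset from the collected source files. A minimal end-to-end sketch, assuming the modules are importable as pytorch.downloader and pytorch.dataset_builder (note that, as committed, both modules also run their example code at import time):

from pytorch.downloader import GitHubRepoDownloader
from pytorch.dataset_builder import CodeDatasetBuilder

# Download and unzip every public repository of the user; unauthenticated
# GitHub API requests are rate-limited, so very large accounts may hit the limit.
downloader = GitHubRepoDownloader(username="lucidrains", download_dir="lucidrains_repositories")
downloader.download_repositories()

# Build a dataset from the downloaded .py files, skipping setup scripts and tests.
builder = CodeDatasetBuilder("lucidrains_repositories")
builder.save_dataset(
    "lucidrains_python_code_dataset",
    exclude_files=["setup.py"],
    exclude_dirs=["tests"],
)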
