generated from kyegomez/Python-Package-Template
-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
documentation for
CodeDatasetBuilder
- Loading branch information
Kye
committed
Sep 15, 2023
1 parent
7b536bf
commit 6c0b2b2
Showing
5 changed files
with
122 additions
and
33 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" | |
|
||
[tool.poetry] | ||
name = "pytorch-dataset" | ||
version = "0.0.3" | ||
version = "0.0.4" | ||
description = "Pytorch Dataset - Pytorch" | ||
license = "MIT" | ||
authors = ["Kye Gomez <[email protected]>"] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,128 @@ | ||
import os | ||
import re | ||
from datasets import Dataset, load_from_disk | ||
|
||
class CodeDatasetBuilder:
    """
    Build, persist, and publish a dataset made of source-code files.

    The builder walks ``root_dir`` recursively, reads every file whose name
    ends with the requested extension, and stores each file's full text as one
    row of a single-column (``"code"``) ``datasets.Dataset``.

    Args:
        root_dir (str): The root directory to search for code files.

    Attributes:
        root_dir (str): The root directory to search for code files.

    Example:
        code_builder = CodeDatasetBuilder("lucidrains_repositories")
        code_builder.save_dataset(
            "lucidrains_python_code_dataset",
            exclude_files=["setup.py"],
            exclude_dirs=["tests"],
        )
        code_builder.push_to_hub("lucidrains_python_code_dataset", organization="kye")
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir

    def collect_code_files(self, file_extension=".py", exclude_files=None, exclude_dirs=None):
        """
        Collect the text of every matching file under ``root_dir``.

        Files are returned in a deterministic order: within each directory the
        files are visited in sorted name order, and sub-directories are then
        descended into in sorted name order.

        Args:
            file_extension (str, optional): Suffix a file name must end with to
                be included (default is ".py").
            exclude_files (list of str, optional): File names (not paths) to
                skip wherever they appear (default is None).
            exclude_dirs (list of str, optional): Directory names (not paths)
                that are never descended into (default is None).

        Returns:
            list of str: The full contents of each collected file. Empty when
            nothing matches or when ``root_dir`` does not exist (``os.walk``
            ignores listing errors by default).

        Raises:
            RuntimeError: If a matching file cannot be opened or decoded as
                UTF-8; the underlying exception is attached as ``__cause__``.
        """
        code_snippets = []
        exclude_files = set(exclude_files) if exclude_files else set()
        exclude_dirs = set(exclude_dirs) if exclude_dirs else set()

        try:
            for root, dirs, files in os.walk(self.root_dir):
                # Assign in place: os.walk only honours pruning done on the
                # list object it handed out. Sorting pins the traversal order,
                # which otherwise depends on the filesystem.
                dirs[:] = sorted(d for d in dirs if d not in exclude_dirs)

                for file in sorted(files):
                    if file.endswith(file_extension) and file not in exclude_files:
                        with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                            code_snippets.append(f.read())
        except Exception as e:
            raise RuntimeError(f"Error while collecting code files: {e}") from e

        return code_snippets

    def create_dataset(self, file_extension=".py", exclude_files=None, exclude_dirs=None):
        """
        Create a dataset from the collected code snippets.

        Args:
            file_extension (str, optional): Suffix a file name must end with to
                be included (default is ".py").
            exclude_files (list of str, optional): File names to skip
                (default is None).
            exclude_dirs (list of str, optional): Directory names that are not
                descended into (default is None).

        Returns:
            datasets.Dataset: A dataset with one ``"code"`` column holding one
            file's text per row.

        Raises:
            RuntimeError: Propagated from ``collect_code_files``.
        """
        code_snippets = self.collect_code_files(file_extension, exclude_files, exclude_dirs)
        return Dataset.from_dict({"code": code_snippets})

    def save_dataset(self, dataset_name, file_extension=".py", exclude_files=None, exclude_dirs=None):
        """
        Build the code dataset and save it to disk.

        Args:
            dataset_name (str): Path passed to ``Dataset.save_to_disk``.
            file_extension (str, optional): Suffix a file name must end with to
                be included (default is ".py").
            exclude_files (list of str, optional): File names to skip
                (default is None).
            exclude_dirs (list of str, optional): Directory names that are not
                descended into (default is None).

        Raises:
            RuntimeError: If collecting the files or saving the dataset fails.
        """
        dataset = self.create_dataset(file_extension, exclude_files, exclude_dirs)
        try:
            dataset.save_to_disk(dataset_name)
        except Exception as e:
            raise RuntimeError(f"Error while saving the dataset: {e}") from e

    def load_dataset(self, dataset_name):
        """
        Load a previously saved code dataset from disk.

        Args:
            dataset_name (str): Path passed to ``load_from_disk``.

        Returns:
            The object returned by ``datasets.load_from_disk``.

        Raises:
            RuntimeError: If the dataset cannot be loaded.
        """
        try:
            return load_from_disk(dataset_name)
        except Exception as e:
            raise RuntimeError(f"Error while loading the dataset: {e}") from e

    def push_to_hub(self, dataset_name, organization=None):
        """
        Load a saved dataset from disk and push it to the Hugging Face Hub.

        Args:
            dataset_name (str): The name of the saved dataset; also used as the
                repository name on the Hub.
            organization (str, optional): Namespace (user or organization) to
                push to. When omitted, the bare dataset name is used as the
                repository id and the Hub resolves it to the logged-in user's
                namespace.

        Raises:
            RuntimeError: If the dataset cannot be loaded or pushed.
        """
        dataset = self.load_dataset(dataset_name)
        # Never fall back to a placeholder namespace: a bare repo name lets the
        # Hub pick the authenticated user's own namespace.
        repo_id = f"{organization}/{dataset_name}" if organization else dataset_name
        try:
            dataset.push_to_hub(repo_id)
        except Exception as e:
            raise RuntimeError(f"Error while pushing the dataset to the Hugging Face Model Hub: {e}") from e
||
# Example usage. Guarded so that importing this module never walks the
# filesystem, writes a dataset to disk, or pushes to the Hub as a side effect;
# the calls only run when the file is executed as a script.
if __name__ == "__main__":
    code_builder = CodeDatasetBuilder("lucidrains_repositories")

    code_builder.save_dataset(
        "lucidrains_python_code_dataset",
        exclude_files=["setup.py"],
        exclude_dirs=["tests"],
    )

    code_builder.push_to_hub("lucidrains_python_code_dataset", organization="kye")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters