Skip to content

Commit

Permalink
dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
rreece committed Mar 9, 2023
1 parent 3a02d73 commit 219aec7
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions scripts/load_csv_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,30 @@
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm



# DEBUG
#torch.set_printoptions(profile="full")



def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse falls back to sys.argv[1:]. Accepting an explicit
            list is backward-compatible and makes this function testable
            without patching sys.argv.

    Returns:
        argparse.Namespace with:
            infiles (list[str]): one or more input csv file paths.
            max_batches (int): max number of batches to process (default 10).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('infiles', nargs='+', default=None,
                        help='Input csv files.')
    parser.add_argument('-m', '--max_batches', type=int, default=10,
                        help='Max batches to process.')
    return parser.parse_args(argv)


class SentimentDataset(Dataset):
def __init__(self, csv_fn,
tokenizer=None,
max_length=512,
return_attention_mask=True):
return_attention_mask=True,
return_token_type_ids=False):
# TODO: handle multiple files
if isinstance(csv_fn, list):
csv_fn = csv_fn[0]
Expand All @@ -34,6 +47,7 @@ def __init__(self, csv_fn,
self.max_length = max_length
self.tokenizer = tokenizer
self.return_attention_mask = return_attention_mask
self.return_token_type_ids = return_token_type_ids

    def __len__(self):
        """Return the number of samples in the dataset.

        Delegates to len(self.df) — presumably the dataframe/table loaded
        from the input csv in __init__ (not visible in this view), so this
        is the row count. TODO(review): confirm self.df is a pandas
        DataFrame.
        """
        return len(self.df)
Expand All @@ -45,12 +59,13 @@ def __getitem__(self, idx):
datum = dict()
if self.tokenizer:
tokenizer_outputs = self.tokenizer(sample,
add_special_tokens=True,
max_length=self.max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
return_token_type_ids=False,
return_attention_mask=self.return_attention_mask,
return_token_type_ids=self.return_token_type_ids,
)
datum["input_ids"] = tokenizer_outputs["input_ids"].squeeze()
if self.return_attention_mask:
Expand All @@ -61,14 +76,6 @@ def __getitem__(self, idx):
return datum



# NOTE(review): this is the pre-commit parse_args shown as the deletion side
# of the diff; the commit replaces it with a version (earlier in the file)
# that also accepts -m/--max_batches. It should not appear in the final file.
def parse_args():
    """Parse command-line arguments: one or more positional input csv files."""
    parser = argparse.ArgumentParser()
    parser.add_argument('infiles', default=None, nargs='+',
                        help='Input csv files.')
    return parser.parse_args()


def get_dataloader(fn):
"""
csv to DataLoader
Expand All @@ -87,11 +94,13 @@ def get_dataloader(fn):
def main():
    """Build a DataLoader from the input csv files and iterate a few batches.

    Reads the file list and batch cap from the command line, constructs the
    DataLoader, then walks up to max_batches batches with a tqdm progress
    bar, throttling each iteration by 0.2 s.
    """
    import time  # function-local: the sleep below is the only use

    args = parse_args()
    infiles = args.infiles
    max_batches = args.max_batches
    dataloader = get_dataloader(infiles)
    print(dataloader)
    for i_batch, batch in tqdm(enumerate(dataloader), total=max_batches):
        # time.sleep is portable and stays in-process; the previous
        # os.system("sleep 0.2s") spawned a shell per batch and fails on
        # platforms without a `sleep` binary (e.g. Windows).
        time.sleep(0.2)
        if i_batch + 1 >= max_batches:
            break
    print("Done.")


Expand Down

0 comments on commit 219aec7

Please sign in to comment.