Skip to content

Commit

Permalink
Update collate_fn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
oriolcolomefont authored Aug 12, 2023
1 parent bdfd372 commit 84b2b54
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions collate_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,28 @@


def collate_fn(batch, loss_type):
"""
A collate function for creating batches of data for training siamese/triplet network models.
Args:
batch (list): A list of dictionaries where each dictionary represents a data sample with keys:
- "anchor" (torch.Tensor): The anchor waveform tensor.
- "positive" (torch.Tensor): The positive waveform tensor.
- "negative" (torch.Tensor): The negative waveform tensor.
- "label" (int): Label for the anchor-positive pair.
- "label_neg" (int): Label for the anchor-negative pair.
loss_type (str): The type of loss function to use. Can be "triplet" or "contrastive".
Returns:
torch.Tensor or tuple: Depending on the loss type, returns either a tuple containing tensors for:
- "triplet" loss: (anchors, positives, hardest_negatives)
- "contrastive" loss: (samples1, samples2, labels)
or a single tensor for "contrastive" loss: labels.
Raises:
ValueError: If an invalid loss type is provided.
"""
if loss_type == "triplet":
return collate_fn_triplet(batch)
elif loss_type == "contrastive":
Expand Down

0 comments on commit 84b2b54

Please sign in to comment.