Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I-JEPA #1273

Merged
merged 40 commits into from
Jul 14, 2023
Merged

I-JEPA #1273

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
5c6d01a
very dirty draft
Jun 1, 2023
a1c0c15
little refactoring
Jun 1, 2023
c85ee97
refactoring
Jun 1, 2023
0ece1d3
refactoring
Jun 1, 2023
6a3f7fe
+ encoder. TODO: decoder based on causal attention
Jun 2, 2023
7102910
fix imports
Jun 2, 2023
0c5bac6
add Decoder class to consistency between code
Jun 2, 2023
7c41f59
change naming; change class structure
Jun 3, 2023
d99994d
change module; add example
Jun 6, 2023
963c7f1
Merge branch 'lightly-ai:master' into master
Natyren Jun 7, 2023
4e585e7
few refactoring
Jun 7, 2023
57f9020
Merge branch 'master' of https://github.com/Natyren/lightly
Jun 7, 2023
9034061
del line
Jun 7, 2023
035ed5a
del comments
Jun 7, 2023
3b41342
del comment
Jun 7, 2023
4646786
del comment
Jun 7, 2023
28fd7db
del line
Natyren Jun 7, 2023
de1106c
pass
Natyren Jun 7, 2023
9995df6
pass
Natyren Jun 7, 2023
c23de39
add model itself, todo: train loop and debug
Natyren Jul 9, 2023
ca171bf
add collator;
Natyren Jul 9, 2023
e1b97ec
added template to train code and transforms
Natyren Jul 9, 2023
612f7cc
add train in pure pytorch
Natyren Jul 12, 2023
d9a9fd8
little fix
Natyren Jul 12, 2023
a77045d
Merge branch 'lightly-ai:master' into master
Natyren Jul 12, 2023
adac38b
little fix
Natyren Jul 12, 2023
800aed6
little fix
Natyren Jul 12, 2023
3a92cda
fix classmethod
Natyren Jul 13, 2023
485f9fc
fix collator
Natyren Jul 13, 2023
80cc077
fixes
Natyren Jul 13, 2023
45510e2
fix in collator
Natyren Jul 13, 2023
263c221
fix collators, added imports, fix models
Natyren Jul 13, 2023
6939ff9
add ijepa backbone
Natyren Jul 13, 2023
0fd4639
finish pure torch impelementation
Natyren Jul 13, 2023
72ca5cb
docstring fix
Natyren Jul 13, 2023
605cdc2
fixes of name and references to original paper
Natyren Jul 14, 2023
1be6453
Format
guarin Jul 14, 2023
d35b7a9
Add note about experimental support
guarin Jul 14, 2023
d262f12
Add datasets to gitignore
guarin Jul 14, 2023
3dafc05
Cleanup imports
guarin Jul 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
lightning_logs/
**lightning_logs/
**/__MACOSX
datasets/
docs/source/tutorials/package/*
docs/source/tutorials/platform/*
docs/source/tutorials_source/platform/data
Expand Down
117 changes: 117 additions & 0 deletions examples/pytorch/ijepa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import copy

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from lightly.data.collate import IJEPAMaskCollator
from lightly.models import utils
from lightly.models.modules.ijepa import IJEPABackbone, IJEPAPredictor
from lightly.transforms.ijepa_transform import IJEPATransform


class IJEPA(nn.Module):
def __init__(self, vit_encoder, vit_predictor, momentum_scheduler):
super().__init__()
self.encoder = IJEPABackbone.from_vit(vit_encoder)
self.predictor = IJEPAPredictor.from_vit_encoder(
vit_predictor.encoder,
(vit_predictor.image_size // vit_predictor.patch_size) ** 2,
)
self.target_encoder = copy.deepcopy(self.encoder)
self.momentum_scheduler = momentum_scheduler

def forward_target(self, imgs, masks_enc, masks_pred):
with torch.no_grad():
h = self.target_encoder(imgs)
h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim
B = len(h)
# -- create targets (masked regions of h)
h = utils.apply_masks(h, masks_pred)
h = utils.repeat_interleave_batch(h, B, repeat=len(masks_enc))
return h

def forward_context(self, imgs, masks_enc, masks_pred):
z = self.encoder(imgs, masks_enc)
z = self.predictor(z, masks_enc, masks_pred)
return z

def forward(self, imgs, masks_enc, masks_pred):
z = self.forward_context(imgs, masks_enc, masks_pred)
h = self.forward_target(imgs, masks_enc, masks_pred)
return z, h

def update_target_encoder(
self,
):
with torch.no_grad():
m = next(self.momentum_scheduler)
for param_q, param_k in zip(
self.encoder.parameters(), self.target_encoder.parameters()
):
param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data)


collator = IJEPAMaskCollator(
input_size=(224, 224),
patch_size=32,
)

transform = IJEPATransform()

# we ignore object detection annotations by setting target_transform to return 0
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=lambda t: 0,
)
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=collator, batch_size=10, persistent_workers=False
)

ema = (0.996, 1.0)
ipe_scale = 1.0
ipe = len(data_loader)
num_epochs = 10
momentum_scheduler = (
ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)
for i in range(int(ipe * num_epochs * ipe_scale) + 1)
)

vit_for_predictor = torchvision.models.vit_b_32(pretrained=False)
vit_for_embedder = torchvision.models.vit_b_32(pretrained=False)
model = IJEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler)

criterion = nn.SmoothL1Loss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print("Starting Training")
for epoch in range(num_epochs):
total_loss = 0
for udata, masks_enc, masks_pred in tqdm(data_loader):

def load_imgs():
# -- unsupervised imgs
imgs = udata[0].to(device, non_blocking=True)
masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]
return (imgs, masks_1, masks_2)

imgs, masks_enc, masks_pred = load_imgs()
z, h = model(imgs, masks_enc, masks_pred)
loss = criterion(z, h)
total_loss += loss.detach()
loss.backward()
optimizer.step()
optimizer.zero_grad()
model.update_target_encoder()

avg_loss = total_loss / len(data_loader)
print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
1 change: 1 addition & 0 deletions examples/pytorch_lightning/ijepa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# TODO
1 change: 1 addition & 0 deletions examples/pytorch_lightning_distributed/ijepa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# TODO
172 changes: 172 additions & 0 deletions lightly/data/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

import math
from multiprocessing import Value
from typing import List, Optional, Tuple, Union
from warnings import warn

Expand Down Expand Up @@ -1345,6 +1347,176 @@
return (views_global, views_local, grids_global, grids_local), labels, fnames


class IJEPAMaskCollator:
"""Collator for IJEPA model [0].

Experimental: Support for I-JEPA is experimental, there might be breaking changes
in the future.

Code inspired by [1].

- [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243
- [1]: https://github.com/facebookresearch/ijepa
"""

def __init__(
self,
input_size=(224, 224),
patch_size=16,
enc_mask_scale=(0.2, 0.8),
pred_mask_scale=(0.2, 0.8),
aspect_ratio=(0.3, 3.0),
nenc=1,
npred=2,
min_keep=4,
allow_overlap=False,
):
if not isinstance(input_size, tuple):
input_size = (input_size,) * 2
self.patch_size = patch_size
self.height, self.width = (

Check warning on line 1377 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1374-L1377

Added lines #L1374 - L1377 were not covered by tests
input_size[0] // patch_size,
input_size[1] // patch_size,
)
self.enc_mask_scale = enc_mask_scale
self.pred_mask_scale = pred_mask_scale
self.aspect_ratio = aspect_ratio
self.nenc = nenc
self.npred = npred
self.min_keep = min_keep # minimum number of patches to keep
self.allow_overlap = (

Check warning on line 1387 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1381-L1387

Added lines #L1381 - L1387 were not covered by tests
allow_overlap # whether to allow overlap b/w enc and pred masks
)
self._itr_counter = Value("i", -1) # collator is shared across worker processes

Check warning on line 1390 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1390

Added line #L1390 was not covered by tests

def step(self):
i = self._itr_counter
with i.get_lock():
i.value += 1
v = i.value
return v

Check warning on line 1397 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1393-L1397

Added lines #L1393 - L1397 were not covered by tests

def _sample_block_size(self, generator, scale, aspect_ratio_scale):
_rand = torch.rand(1, generator=generator).item()

Check warning on line 1400 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1400

Added line #L1400 was not covered by tests
# -- Sample block scale
min_s, max_s = scale
mask_scale = min_s + _rand * (max_s - min_s)
max_keep = int(self.height * self.width * mask_scale)

Check warning on line 1404 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1402-L1404

Added lines #L1402 - L1404 were not covered by tests
# -- Sample block aspect-ratio
min_ar, max_ar = aspect_ratio_scale
aspect_ratio = min_ar + _rand * (max_ar - min_ar)

Check warning on line 1407 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1406-L1407

Added lines #L1406 - L1407 were not covered by tests
# -- Compute block height and width (given scale and aspect-ratio)
h = int(round(math.sqrt(max_keep * aspect_ratio)))
w = int(round(math.sqrt(max_keep / aspect_ratio)))
while h >= self.height:
h -= 1
while w >= self.width:
w -= 1

Check warning on line 1414 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1409-L1414

Added lines #L1409 - L1414 were not covered by tests

return (h, w)

Check warning on line 1416 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1416

Added line #L1416 was not covered by tests

def _sample_block_mask(self, b_size, acceptable_regions=None):
h, w = b_size

Check warning on line 1419 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1419

Added line #L1419 was not covered by tests

def constrain_mask(mask, tries=0):

Check warning on line 1421 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1421

Added line #L1421 was not covered by tests
"""Helper to restrict given mask to a set of acceptable regions"""
N = max(int(len(acceptable_regions) - tries), 0)
for k in range(N):
mask *= acceptable_regions[k]

Check warning on line 1425 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1423-L1425

Added lines #L1423 - L1425 were not covered by tests

# --
# -- Loop to sample masks until we find a valid one
tries = 0
timeout = og_timeout = 20
valid_mask = False
while not valid_mask:

Check warning on line 1432 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1429-L1432

Added lines #L1429 - L1432 were not covered by tests
# -- Sample block top-left corner
top = torch.randint(0, self.height - h, (1,))
left = torch.randint(0, self.width - w, (1,))
mask = torch.zeros((self.height, self.width), dtype=torch.int32)
mask[top : top + h, left : left + w] = 1

Check warning on line 1437 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1434-L1437

Added lines #L1434 - L1437 were not covered by tests
# -- Constrain mask to a set of acceptable regions
if acceptable_regions is not None:
constrain_mask(mask, tries)
mask = torch.nonzero(mask.flatten())

Check warning on line 1441 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1439-L1441

Added lines #L1439 - L1441 were not covered by tests
# -- If mask too small try again
valid_mask = len(mask) > self.min_keep
if not valid_mask:
timeout -= 1
if timeout == 0:
tries += 1
timeout = og_timeout
mask = mask.squeeze()

Check warning on line 1449 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1443-L1449

Added lines #L1443 - L1449 were not covered by tests
# --
mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
mask_complement[top : top + h, left : left + w] = 0

Check warning on line 1452 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1451-L1452

Added lines #L1451 - L1452 were not covered by tests
# --
return mask, mask_complement

Check warning on line 1454 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1454

Added line #L1454 was not covered by tests

def __call__(self, batch):
"""
Create encoder and predictor masks when collating imgs into a batch
# 1. sample enc block (size + location) using seed
# 2. sample pred block (size) using seed
# 3. sample several enc block locations for each image (w/o seed)
# 4. sample several pred block locations for each image (w/o seed)
# 5. return enc mask and pred mask
"""
B = len(batch)

Check warning on line 1465 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1465

Added line #L1465 was not covered by tests

collated_batch = torch.utils.data.default_collate(batch)

Check warning on line 1467 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1467

Added line #L1467 was not covered by tests

seed = self.step()
g = torch.Generator()
g.manual_seed(seed)
p_size = self._sample_block_size(

Check warning on line 1472 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1469-L1472

Added lines #L1469 - L1472 were not covered by tests
generator=g,
scale=self.pred_mask_scale,
aspect_ratio_scale=self.aspect_ratio,
)
e_size = self._sample_block_size(

Check warning on line 1477 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1477

Added line #L1477 was not covered by tests
generator=g, scale=self.enc_mask_scale, aspect_ratio_scale=(1.0, 1.0)
)

collated_masks_pred, collated_masks_enc = [], []
min_keep_pred = self.height * self.width
min_keep_enc = self.height * self.width
for _ in range(B):
masks_p, masks_C = [], []
for _ in range(self.npred):
mask, mask_C = self._sample_block_mask(p_size)
masks_p.append(mask)
masks_C.append(mask_C)
min_keep_pred = min(min_keep_pred, len(mask))
collated_masks_pred.append(masks_p)

Check warning on line 1491 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1481-L1491

Added lines #L1481 - L1491 were not covered by tests

acceptable_regions = masks_C

Check warning on line 1493 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1493

Added line #L1493 was not covered by tests

if self.allow_overlap:
acceptable_regions = None

Check warning on line 1496 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1495-L1496

Added lines #L1495 - L1496 were not covered by tests

masks_e = []
for _ in range(self.nenc):
mask, _ = self._sample_block_mask(

Check warning on line 1500 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1498-L1500

Added lines #L1498 - L1500 were not covered by tests
e_size, acceptable_regions=acceptable_regions
)
masks_e.append(mask)
min_keep_enc = min(min_keep_enc, len(mask))
collated_masks_enc.append(masks_e)

Check warning on line 1505 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1503-L1505

Added lines #L1503 - L1505 were not covered by tests

collated_masks_pred = [

Check warning on line 1507 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1507

Added line #L1507 was not covered by tests
[cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred
]
collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)

Check warning on line 1510 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1510

Added line #L1510 was not covered by tests
# --
collated_masks_enc = [

Check warning on line 1512 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1512

Added line #L1512 was not covered by tests
[cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc
]
collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)

Check warning on line 1515 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1515

Added line #L1515 was not covered by tests

return collated_batch, collated_masks_enc, collated_masks_pred

Check warning on line 1517 in lightly/data/collate.py

View check run for this annotation

Codecov / codecov/patch

lightly/data/collate.py#L1517

Added line #L1517 was not covered by tests


def _deprecation_warning_collate_functions() -> None:
warn(
"Collate functions are deprecated and will be removed in favor of transforms in v1.4.0.\n"
Expand Down
Loading
Loading