Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: pass on device to DataParallel; line breaks. #65

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ setup.cfg
examples/*
notebooks/*

tests/*
scripts/tmp/*
scripts/dataset_stats/*
scripts/leaderboard/*
Expand Down
22 changes: 15 additions & 7 deletions cleanfid/downloads_helper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import urllib.request
import requests
import shutil
import urllib.request

import requests

inception_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt"

Expand All @@ -12,11 +12,15 @@
ARGS:
fpath - output folder path
"""


def check_download_inception(fpath="./"):
inception_path = os.path.join(fpath, "inception-2015-12-05.pt")
if not os.path.exists(inception_path):
# download the file
with urllib.request.urlopen(inception_url) as response, open(inception_path, 'wb') as f:
with urllib.request.urlopen(inception_url) as response, open(
inception_path, "wb"
) as f:
shutil.copyfileobj(response, f)
return inception_path

Expand All @@ -27,13 +31,15 @@ def check_download_inception(fpath="./"):
local_folder - output folder path
url - the weburl to download
"""


def check_download_url(local_folder, url):
name = os.path.basename(url)
local_path = os.path.join(local_folder, name)
if not os.path.exists(local_path):
os.makedirs(local_folder, exist_ok=True)
print(f"downloading statistics to {local_path}")
with urllib.request.urlopen(url) as response, open(local_path, 'wb') as f:
with urllib.request.urlopen(url) as response, open(local_path, "wb") as f:
shutil.copyfileobj(response, f)
return local_path

Expand All @@ -44,20 +50,22 @@ def check_download_url(local_folder, url):
file_id - id of the google drive file
out_path - output folder path
"""


def download_google_drive(file_id, out_path):
def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
if key.startswith("download_warning"):
return value
return None

URL = "https://drive.google.com/uc?export=download"
session = requests.Session()
response = session.get(URL, params={'id': file_id}, stream=True)
response = session.get(URL, params={"id": file_id}, stream=True)
token = get_confirm_token(response)

if token:
params = {'id': file_id, 'confirm': token}
params = {"id": file_id, "confirm": token}
response = session.get(URL, params=params, stream=True)

CHUNK_SIZE = 32768
Expand Down
71 changes: 58 additions & 13 deletions cleanfid/features.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,52 @@
"""
helpers for extracting features from image
"""

import os
import platform

import numpy as np
import torch

import cleanfid
from cleanfid.downloads_helper import check_download_url
from cleanfid.inception_pytorch import InceptionV3
from cleanfid.inception_torchscript import InceptionV3W


"""
returns a functions that takes an image in range [0,255]
and outputs a feature embedding vector
"""
def feature_extractor(name="torchscript_inception", device=torch.device("cuda"), resize_inside=False, use_dataparallel=True):


def feature_extractor(
name: str = "torchscript_inception",
device=torch.device("cuda"),
resize_inside=False,
use_dataparallel=True,
):
device_ids = [device]
if name == "torchscript_inception":
path = "./" if platform.system() == "Windows" else "/tmp"
model = InceptionV3W(path, download=True, resize_inside=resize_inside).to(device)
model = InceptionV3W(
path,
download=True,
resize_inside=resize_inside,
).to(device)
model.eval()
if use_dataparallel:
model = torch.nn.DataParallel(model)
def model_fn(x): return model(x)
model = torch.nn.DataParallel(model, device_ids=device_ids)

def model_fn(x):
return model(x)
elif name == "pytorch_inception":
model = InceptionV3(output_blocks=[3], resize_input=False).to(device)
model.eval()
if use_dataparallel:
model = torch.nn.DataParallel(model)
def model_fn(x): return model(x/255)[0].squeeze(-1).squeeze(-1)
model = torch.nn.DataParallel(model, device_ids=device_ids)

def model_fn(x):
return model(x / 255)[0].squeeze(-1).squeeze(-1)
else:
raise ValueError(f"{name} feature extractor not implemented")
return model_fn
Expand All @@ -37,27 +55,54 @@ def model_fn(x): return model(x/255)[0].squeeze(-1).squeeze(-1)
"""
Build a feature extractor for each of the modes
"""


def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True):
if mode == "legacy_pytorch":
feat_model = feature_extractor(name="pytorch_inception", resize_inside=False, device=device, use_dataparallel=use_dataparallel)
feat_model = feature_extractor(
name="pytorch_inception",
resize_inside=False,
device=device,
use_dataparallel=use_dataparallel,
)
elif mode == "legacy_tensorflow":
feat_model = feature_extractor(name="torchscript_inception", resize_inside=True, device=device, use_dataparallel=use_dataparallel)
feat_model = feature_extractor(
name="torchscript_inception",
resize_inside=True,
device=device,
use_dataparallel=use_dataparallel,
)
elif mode == "clean":
feat_model = feature_extractor(name="torchscript_inception", resize_inside=False, device=device, use_dataparallel=use_dataparallel)
feat_model = feature_extractor(
name="torchscript_inception",
resize_inside=False,
device=device,
use_dataparallel=use_dataparallel,
)
return feat_model


"""
Load precomputed reference statistics for commonly used datasets
"""
def get_reference_statistics(name, res, mode="clean", model_name="inception_v3", seed=0, split="test", metric="FID"):


def get_reference_statistics(
name,
res,
mode="clean",
model_name="inception_v3",
seed=0,
split="test",
metric="FID",
):
base_url = "https://www.cs.cmu.edu/~clean-fid/stats/"
if split == "custom":
res = "na"
if model_name=="inception_v3":
if model_name == "inception_v3":
model_modifier = ""
else:
model_modifier = "_"+model_name
model_modifier = "_" + model_name
if metric == "FID":
rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}.npz").lower()
url = f"{base_url}/{rel_path}"
Expand Down
Loading