Skip to content

Commit

Permalink
allow for assigning a specific gpu id for training, readying for mult…
Browse files Browse the repository at this point in the history
…i-gpu
  • Loading branch information
lucidrains committed Sep 22, 2020
1 parent 0c9b33c commit b95548c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
24 changes: 17 additions & 7 deletions bin/stylegan2_pytorch
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
#!/usr/bin/env python
import os
import fire
from retry.api import retry_call
from tqdm import tqdm
from stylegan2_pytorch import Trainer, NanException
from datetime import datetime

def cast_list(el):
return el if isinstance(el, list) else [el]

def timestamped_filename(prefix = 'generated-'):
now = datetime.now()
timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
return f'{prefix}{timestamp}'

def train_from_folder(
data = './data',
results_dir = './results',
Expand Down Expand Up @@ -37,7 +45,13 @@ def train_from_folder(
no_const = False,
aug_prob = 0.,
dataset_aug_prob = 0.,
gpu_ids = [0]
):
gpu_ids = cast_list(gpu_ids)
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpu_ids))

from stylegan2_pytorch import Trainer, NanException

model = Trainer(
name,
results_dir,
Expand Down Expand Up @@ -70,17 +84,13 @@ def train_from_folder(
model.clear()

if generate:
now = datetime.now()
timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
samples_name = f'generated-{timestamp}'
samples_name = timestamped_filename()
model.evaluate(samples_name, num_image_tiles)
print(f'sample images generated at {results_dir}/{name}/{samples_name}')
return

if generate_interpolation:
now = datetime.now()
timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
samples_name = f'generated-{timestamp}'
samples_name = timestamped_filename()
model.generate_interpolation(samples_name, num_image_tiles, save_frames = save_frames)
print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
return
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'stylegan2_pytorch',
packages = find_packages(),
scripts=['bin/stylegan2_pytorch'],
version = '0.21.0',
version = '0.21.1',
license='GPLv3+',
description = 'StyleGan2 in Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit b95548c

Please sign in to comment.