diff --git a/README.md b/README.md index e39c75a15..72c59b99c 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ bash launch/run_train.sh ``` To test: ``` -python test --input_path path/to/img +python test.py --input_path path/to/img ``` ## Citation diff --git a/data/dataset_aus.py b/data/dataset_aus.py index 24192741f..e2edf9ed9 100644 --- a/data/dataset_aus.py +++ b/data/dataset_aus.py @@ -34,9 +34,9 @@ def __getitem__(self, index): real_cond = self._get_cond_by_id(sample_id) if real_img is None: - print 'error reading image %s, skipping sample' % sample_id + print('error reading image %s, skipping sample' % sample_id) if real_cond is None: - print 'error reading aus %s, skipping sample' % sample_id + print('error reading aus %s, skipping sample' % sample_id) desired_cond = self._generate_random_cond() @@ -92,11 +92,11 @@ def _create_transform(self): def _read_ids(self, file_path): ids = np.loadtxt(file_path, delimiter='\t', dtype=np.str) - return [id[:-4] for id in ids] + return [str.encode(id[:-4]) for id in ids] def _read_conds(self, file_path): with open(file_path, 'rb') as f: - return pickle.load(f) + return pickle.load(f, encoding='bytes') def _get_cond_by_id(self, id): if id in self._conds: diff --git a/launch/run_train.sh b/launch/run_train.sh index b427f40cd..740eb220a 100644 --- a/launch/run_train.sh +++ b/launch/run_train.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash python train.py \ ---data_dir path/to/dataset/ \ +--data_dir sample_dataset \ --name experiment_1 \ --batch_size 25 \ diff --git a/models/models.py b/models/models.py index 34e1ae024..fbe74492a 100644 --- a/models/models.py +++ b/models/models.py @@ -90,13 +90,13 @@ def _load_optimizer(self, optimizer, optimizer_label, epoch_label): load_path), 'Weights file not found. Have you trained a model!? We are not providing one' % load_path optimizer.load_state_dict(torch.load(load_path)) - print 'loaded optimizer: %s' % load_path + print('loaded optimizer: %s' % load_path) def _save_network(self, network, network_label, epoch_label): save_filename = 'net_epoch_%s_id_%s.pth' % (epoch_label, network_label) save_path = os.path.join(self._save_dir, save_filename) torch.save(network.state_dict(), save_path) - print 'saved net: %s' % save_path + print('saved net: %s' % save_path) def _load_network(self, network, network_label, epoch_label): load_filename = 'net_epoch_%s_id_%s.pth' % (epoch_label, network_label) @@ -105,7 +105,7 @@ def _load_network(self, network, network_label, epoch_label): load_path), 'Weights file not found. Have you trained a model!? We are not providing one' % load_path network.load_state_dict(torch.load(load_path)) - print 'loaded net: %s' % load_path + print('loaded net: %s' % load_path) def update_learning_rate(self): pass diff --git a/networks/networks.py b/networks/networks.py index c2fb9ed76..6da104a44 100644 --- a/networks/networks.py +++ b/networks/networks.py @@ -17,7 +17,7 @@ def get_by_name(network_name, *args, **kwargs): else: raise ValueError("Network %s not recognized." % network_name) - print "Network %s was created" % network_name + print("Network %s was created" % network_name) return network