Skip to content

Commit

Permalink
Merge pull request #584 from mala-project/validate_every_n_epochs
Browse files (browse the repository at this point in the history)
Validation every N steps
  • Loading branch information
RandomDefaultUser authored Oct 14, 2024
2 parents ff0ea2f + 032feb7 commit e062deb
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
16 changes: 11 additions & 5 deletions examples/advanced/ex03_tensor_board.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
parameters = mala.Parameters()
parameters.data.input_rescaling_type = "feature-wise-standard"
parameters.data.output_rescaling_type = "normal"
parameters.targets.ldos_gridsize = 11
parameters.targets.ldos_gridspacing_ev = 2.5
parameters.targets.ldos_gridoffset_ev = -5
parameters.network.layer_activations = ["ReLU"]
parameters.running.max_number_epochs = 100
parameters.running.mini_batch_size = 40
Expand All @@ -22,16 +25,19 @@

# Turn the visualization on and select a folder to save the visualization
# files into.
parameters.running.visualisation = 1
parameters.running.visualisation_dir = "mala_vis"

parameters.running.logger = "tensorboard"
parameters.running.logging_dir = "mala_vis"
parameters.running.validation_metrics = ["ldos", "band_energy"]
parameters.running.validate_every_n_epochs = 5

data_handler = mala.DataHandler(parameters)
data_handler.add_snapshot(
"Be_snapshot0.in.npy", data_path, "Be_snapshot0.out.npy", data_path, "tr"
"Be_snapshot0.in.npy", data_path, "Be_snapshot0.out.npy", data_path, "tr",
calculation_output_file=os.path.join(data_path, "Be_snapshot0.out"),
)
data_handler.add_snapshot(
"Be_snapshot1.in.npy", data_path, "Be_snapshot1.out.npy", data_path, "va"
"Be_snapshot1.in.npy", data_path, "Be_snapshot1.out.npy", data_path, "va",
calculation_output_file=os.path.join(data_path, "Be_snapshot1.out"),
)
data_handler.prepare_data()
parameters.network.layer_sizes = [
Expand Down
1 change: 1 addition & 0 deletions mala/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def __init__(self):
self.logger = "tensorboard"
self.validation_metrics = ["ldos"]
self.validate_on_training_data = False
self.validate_every_n_epochs = 1
self.inference_data_grid = [0, 0, 0]
self.use_mixed_precision = False
self.use_graphs = False
Expand Down
10 changes: 8 additions & 2 deletions mala/network/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,16 +434,22 @@ def train_network(self):
self.network, inputs, outputs
)
batchid += 1
total_batch_id += 1

dataset_fractions = ["validation"]
if self.parameters.validate_on_training_data:
dataset_fractions.append("train")
validation_metrics = ["ldos"]
if (epoch != 0 and
(epoch - 1) % self.parameters.validate_every_n_epochs == 0):
validation_metrics = self.parameters.validation_metrics
errors = self._validate_network(
dataset_fractions, self.parameters.validation_metrics
dataset_fractions, validation_metrics
)
for dataset_fraction in dataset_fractions:
for metric in errors[dataset_fraction]:
errors[dataset_fraction][metric] = np.mean(
errors[dataset_fraction][metric]
np.abs(errors[dataset_fraction][metric])
)
vloss = errors["validation"][
self.parameters.during_training_metric
Expand Down

0 comments on commit e062deb

Please sign in to comment.