Propagate separation of state and boundary change through training loop
joeloskarsson committed Oct 28, 2024
1 parent 5565e43 commit 365675d
Showing 5 changed files with 135 additions and 99 deletions.
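In short: the rollout no longer overwrites border nodes with true states after each step; boundary values instead enter predict_step as a separate boundary_forcing input, and grid tensors now cover interior nodes only. A minimal before/after sketch of the rollout loop (schematic only; "model" and the mask arithmetic are illustrative placeholders, not the repository's API):

import torch

def unroll_old(model, init_states, forcing, true_states, border_mask):
    # Old behaviour: predict everywhere, then overwrite border with truth
    prev_prev, prev = init_states[:, 0], init_states[:, 1]
    preds = []
    for i in range(forcing.shape[1]):
        pred, _ = model(prev, prev_prev, forcing[:, i])
        new = border_mask * true_states[:, i] + (1.0 - border_mask) * pred
        preds.append(new)
        prev_prev, prev = prev, new
    return torch.stack(preds, dim=1)

def unroll_new(model, init_states, forcing, boundary_forcing):
    # New behaviour: boundary is a separate input; no post-hoc masking
    prev_prev, prev = init_states[:, 0], init_states[:, 1]
    preds = []
    for i in range(forcing.shape[1]):
        pred, _ = model(prev, prev_prev, forcing[:, i], boundary_forcing[:, i])
        preds.append(pred)  # state tensors now cover interior nodes only
        prev_prev, prev = prev, pred
    return torch.stack(preds, dim=1)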
80 changes: 33 additions & 47 deletions neural_lam/models/ar_model.py
@@ -58,20 +58,21 @@ def __init__(self, args):
self.num_grid_nodes,
grid_static_dim,
) = self.grid_static_features.shape
(
self.num_boundary_nodes,
boundary_static_dim, # TODO Needed for computation below
) = self.boundary_static_features.shape
self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes
self.grid_dim = (
2 * self.config_loader.num_data_vars()
+ grid_static_dim
+ self.config_loader.dataset.num_forcing_features
)
self.boundary_dim = self.grid_dim # TODO Compute separately

# Instantiate loss function
self.loss = metrics.get_metric(args.loss)

# Pre-compute interior mask for use in loss function
self.register_buffer(
"interior_mask", 1.0 - self.border_mask, persistent=False
) # (num_grid_nodes, 1), 1 for non-border

self.step_length = args.step_length # Number of hours per pred. step
self.val_metrics = {
"mse": [],
@@ -99,21 +100,16 @@ def configure_optimizers(self):
)
return opt

@property
def interior_mask_bool(self):
"""
Get the interior mask as a boolean (N,) mask.
"""
return self.interior_mask[:, 0].to(torch.bool)

@staticmethod
def expand_to_batch(x, batch_size):
"""
Expand tensor with initial batch dimension
"""
return x.unsqueeze(0).expand(batch_size, -1, -1)

def predict_step(self, prev_state, prev_prev_state, forcing):
def predict_step(
self, prev_state, prev_prev_state, forcing, boundary_forcing
):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
prev_state: (B, num_grid_nodes, feature_dim), X_t
@@ -122,42 +118,36 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
"""
raise NotImplementedError("No prediction step implemented")

def unroll_prediction(self, init_states, forcing_features, true_states):
def unroll_prediction(self, init_states, forcing, boundary_forcing):
"""
Roll out prediction taking multiple autoregressive steps with model
init_states: (B, 2, num_grid_nodes, d_f)
forcing_features: (B, pred_steps, num_grid_nodes, d_static_f)
true_states: (B, pred_steps, num_grid_nodes, d_f)
boundary_forcing: (B, pred_steps, num_boundary_nodes, d_boundary_f)
"""
prev_prev_state = init_states[:, 0]
prev_state = init_states[:, 1]
prediction_list = []
pred_std_list = []
pred_steps = forcing_features.shape[1]
pred_steps = forcing.shape[1]

for i in range(pred_steps):
forcing = forcing_features[:, i]
border_state = true_states[:, i]
forcing_step = forcing[:, i]
boundary_forcing_step = boundary_forcing[:, i]

pred_state, pred_std = self.predict_step(
prev_state, prev_prev_state, forcing
prev_state, prev_prev_state, forcing_step, boundary_forcing_step
)
# state: (B, num_grid_nodes, d_f)
# pred_std: (B, num_grid_nodes, d_f) or None

# Overwrite border with true state
new_state = (
self.border_mask * border_state
+ self.interior_mask * pred_state
)

prediction_list.append(new_state)
prediction_list.append(pred_state)
if self.output_std:
pred_std_list.append(pred_std)

# Update conditioning states
prev_prev_state = prev_state
prev_state = new_state
prev_state = pred_state

prediction = torch.stack(
prediction_list, dim=1
@@ -177,17 +167,15 @@ def common_step(self, batch):
batch consists of:
init_states: (B, 2, num_grid_nodes, d_features)
target_states: (B, pred_steps, num_grid_nodes, d_features)
forcing_features: (B, pred_steps, num_grid_nodes, d_forcing),
forcing: (B, pred_steps, num_grid_nodes, d_forcing),
boundary_forcing:
(B, pred_steps, num_boundary_nodes, d_boundary_forcing),
where index 0 corresponds to index 1 of init_states
"""
(
init_states,
target_states,
forcing_features,
) = batch
(init_states, target_states, forcing, boundary_forcing) = batch

prediction, pred_std = self.unroll_prediction(
init_states, forcing_features, target_states
init_states, forcing, boundary_forcing
) # (B, pred_steps, num_grid_nodes, d_f)
# prediction: (B, pred_steps, num_grid_nodes, d_f)
# pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)
@@ -203,7 +191,9 @@ def training_step(self, batch):
# Compute loss
batch_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
prediction,
target,
pred_std,
)
) # mean over unrolled times and batch

@@ -234,7 +224,9 @@ def validation_step(self, batch, batch_idx):

time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
prediction,
target,
pred_std,
),
dim=0,
) # (time_steps-1)
@@ -255,7 +247,6 @@ def validation_step(self, batch, batch_idx):
prediction,
target,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)
self.val_metrics["mse"].append(entry_mses)
@@ -282,7 +273,9 @@ def test_step(self, batch, batch_idx):

time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
prediction,
target,
pred_std,
),
dim=0,
) # (time_steps-1,)
@@ -309,16 +302,13 @@ def test_step(self, batch, batch_idx):
prediction,
target,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)
self.test_metrics[metric_name].append(batch_metric_vals)

if self.output_std:
# Store output std. per variable, spatially averaged
mean_pred_std = torch.mean(
pred_std[..., self.interior_mask_bool, :], dim=-2
) # (B, pred_steps, d_f)
mean_pred_std = torch.mean(pred_std, dim=-2) # (B, pred_steps, d_f)
self.test_metrics["output_std"].append(mean_pred_std)

# Save per-sample spatial loss for specific times
@@ -397,7 +387,6 @@ def plot_examples(self, batch, n_examples, prediction=None):
vis.plot_prediction(
pred_t[:, var_i],
target_t[:, var_i],
self.interior_mask[:, 0],
self.config_loader,
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self.step_length * t_i} h)",
@@ -541,7 +530,6 @@ def on_test_epoch_end(self):
loss_map_figs = [
vis.plot_spatial_error(
loss_map,
self.interior_mask[:, 0],
self.config_loader,
title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
)
@@ -556,9 +544,7 @@ def on_test_epoch_end(self):

# also make without title and save as pdf
pdf_loss_map_figs = [
vis.plot_spatial_error(
loss_map, self.interior_mask[:, 0], self.config_loader
)
vis.plot_spatial_error(loss_map, self.config_loader)
for loss_map in mean_spatial_loss
]
pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
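Note on the dropped mask= arguments above: since grid tensors are now interior-only (see the utils.py hunk below), masking inside the metric is redundant. A hedged sketch of the equivalence, assuming a simple masked-MSE metric (toy stand-in, not the repository's metrics module):

import torch

def mse(pred, target, mask=None):
    # Toy stand-in for a masked metric; not the repository's code
    se = (pred - target) ** 2  # (B, pred_steps, N, d_f)
    if mask is not None:
        se = se[..., mask, :]  # restrict to interior nodes
    return se.mean(dim=(-2, -1))  # (B, pred_steps)

B, T, n_interior, n_boundary, d_f = 2, 4, 10, 6, 3
interior = torch.cat(
    [torch.ones(n_interior), torch.zeros(n_boundary)]
).to(torch.bool)

# Old layout: tensors over all nodes, interior selected inside the metric
pred_full = torch.randn(B, T, n_interior + n_boundary, d_f)
target_full = torch.randn(B, T, n_interior + n_boundary, d_f)
old = mse(pred_full, target_full, mask=interior)

# New layout: tensors are interior-only from the start, no mask needed
new = mse(pred_full[..., interior, :], target_full[..., interior, :])
assert torch.allclose(old, new)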
38 changes: 35 additions & 3 deletions neural_lam/models/base_graph_model.py
@@ -41,6 +41,12 @@ def __init__(self, args):
# Define sub-models
# Feature embedders for grid
self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1)
# TODO Optional separate embedder for boundary nodes
assert self.grid_dim == self.boundary_dim, (
"Grid and boundary input dimension must be the same when using "
f"the same encoder, got grid_dim={self.grid_dim}, "
f"boundary_dim={self.boundary_dim}"
)
self.grid_embedder = utils.make_mlp(
[self.grid_dim] + self.mlp_blueprint_end
)
@@ -98,12 +104,15 @@ def process_step(self, mesh_rep):
"""
raise NotImplementedError("process_step not implemented")

def predict_step(self, prev_state, prev_prev_state, forcing):
def predict_step(
self, prev_state, prev_prev_state, forcing, boundary_forcing
):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
prev_state: (B, num_grid_nodes, feature_dim), X_t
prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1}
forcing: (B, num_grid_nodes, forcing_dim)
boundary_forcing: (B, num_boundary_nodes, boundary_forcing_dim)
"""
batch_size = prev_state.shape[0]

@@ -117,22 +126,45 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
),
dim=-1,
)
# Create full boundary node features of shape
# (B, num_boundary_nodes, boundary_dim)
boundary_features = torch.cat(
(
boundary_forcing,
self.expand_to_batch(self.boundary_static_features, batch_size),
),
dim=-1,
)

# Embed all features
grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h)
boundary_emb = self.grid_embedder(boundary_features)
# (B, num_boundary_nodes, d_h)
g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h)
m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h)
mesh_emb = self.embedd_mesh_nodes()

# Merge interior and boundary emb into input embedding
# TODO Can we enforce ordering in the graph creation process to make
# this just a concat instead?
input_emb = torch.zeros(
batch_size,
self.num_input_nodes,
grid_emb.shape[2],
device=grid_emb.device,
)
input_emb[:, self.interior_mask] = grid_emb
input_emb[:, self.boundary_mask] = boundary_emb

# Map from grid to mesh
mesh_emb_expanded = self.expand_to_batch(
mesh_emb, batch_size
) # (B, num_mesh_nodes, d_h)
g2m_emb_expanded = self.expand_to_batch(g2m_emb, batch_size)

# This also splits representation into grid and mesh
# Encode to mesh
mesh_rep = self.g2m_gnn(
grid_emb, mesh_emb_expanded, g2m_emb_expanded
input_emb, mesh_emb_expanded, g2m_emb_expanded
) # (B, num_mesh_nodes, d_h)
# Also MLP with residual for grid representation
grid_rep = grid_emb + self.encoding_grid_mlp(
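On the in-code TODO about ordering: if graph creation guaranteed that interior nodes precede boundary nodes, the allocate-and-scatter merge in predict_step would reduce to a single concat. A small sketch of that equivalence under the ordering assumption (all names illustrative):

import torch

B, n_interior, n_boundary, d_h = 2, 10, 6, 8
grid_emb = torch.randn(B, n_interior, d_h)
boundary_emb = torch.randn(B, n_boundary, d_h)

# Assumed ordering: interior nodes first, then boundary nodes
interior_mask = torch.cat(
    [torch.ones(n_interior), torch.zeros(n_boundary)]
).to(torch.bool)
boundary_mask = torch.logical_not(interior_mask)

# Merge as in predict_step: allocate, then scatter by boolean mask
input_emb = torch.zeros(B, n_interior + n_boundary, d_h)
input_emb[:, interior_mask] = grid_emb
input_emb[:, boundary_mask] = boundary_emb

# Under the ordering guarantee this is just a concatenation
assert torch.equal(input_emb, torch.cat((grid_emb, boundary_emb), dim=1))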
24 changes: 16 additions & 8 deletions neural_lam/utils.py
@@ -46,16 +46,22 @@ def loads_file(fn):
)

# Load border mask, 1. if node is part of border, else 0.
border_mask_np = np.load(os.path.join(static_dir_path, "border_mask.npy"))
border_mask = (
torch.tensor(border_mask_np, dtype=torch.float32, device=device)
boundary_mask_np = np.load(os.path.join(static_dir_path, "border_mask.npy"))
boundary_mask = (
torch.tensor(boundary_mask_np, dtype=torch.float32, device=device)
.flatten(0, 1)
.unsqueeze(1)
) # (N_grid, 1)
.to(torch.bool)
) # (N_grid,)
interior_mask = torch.logical_not(boundary_mask)

grid_static_features = loads_file(
full_grid_static_features = loads_file(
"grid_features.pt"
) # (N_grid, d_grid_static)
) # (N_full_grid, d_grid_static)

grid_static_features = full_grid_static_features[interior_mask]
# (num_grid_nodes, d_grid_static)
boundary_static_features = full_grid_static_features[boundary_mask]
# (num_boundary_nodes, d_grid_static)

# Load step diff stats
step_diff_mean = loads_file("diff_mean.pt") # (d_f,)
@@ -73,8 +79,10 @@ def loads_file(fn):
) # (d_f,)

return {
"border_mask": border_mask,
"boundary_mask": boundary_mask,
"interior_mask": interior_mask,
"grid_static_features": grid_static_features,
"boundary_static_features": boundary_static_features,
"step_diff_mean": step_diff_mean,
"step_diff_std": step_diff_std,
"data_mean": data_mean,
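The loader now returns boolean boundary/interior masks plus pre-split static features, so downstream code indexes the two node sets directly. A hedged usage sketch with toy tensors standing in for the real border_mask.npy and grid_features.pt files:

import torch

# Toy 4x3 grid in place of the real static files
border_mask_2d = torch.zeros(4, 3)
border_mask_2d[0, :] = 1.0  # top row acts as boundary
full_grid_static_features = torch.randn(12, 5)  # (N_full_grid, d_grid_static)

boundary_mask = border_mask_2d.flatten(0, 1).to(torch.bool)  # (N_grid,)
interior_mask = torch.logical_not(boundary_mask)

grid_static_features = full_grid_static_features[interior_mask]
boundary_static_features = full_grid_static_features[boundary_mask]
# Interior and boundary nodes partition the full grid
assert (
    grid_static_features.shape[0] + boundary_static_features.shape[0]
    == full_grid_static_features.shape[0]
)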