Skip to content

Commit

Permalink
Fully refactored code
Browse files Browse the repository at this point in the history
  • Loading branch information
sidekock committed Oct 9, 2024
1 parent 0c5185f commit 46bc44a
Showing 1 changed file with 191 additions and 151 deletions.
342 changes: 191 additions & 151 deletions pysteps/nowcasts/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, precip, velocity, timesteps, **kwargs):
self.noise_std_coeffs = None
self.randgen_prec = None
self.randgen_motion = None
self.velocity_perturbators = None
self.velocity_perturbations = None
self.precip_forecast = None
self.mask_prec = None
self.mask_thr = None
Expand Down Expand Up @@ -114,7 +114,7 @@ def compute_forecast(self):

self._perform_extrapolation()
self._apply_noise_and_ar_model()
self._initialize_velocity_perturbators()
self._initialize_velocity_perturbations()
self._initialize_precipitation_mask()
self._initialize_fft_objects()
# Measure and print initialization time
Expand Down Expand Up @@ -164,7 +164,7 @@ def _nowcast_main(self):
self.extrap_method,
self._update_state, # Reference to the update function
extrap_kwargs=self.extrap_kwargs,
velocity_pert_gen=self.velocity_perturbators,
velocity_pert_gen=self.velocity_perturbations,
params=params,
ensemble=True,
num_ensemble_members=self.n_ens_members,
Expand Down Expand Up @@ -557,15 +557,15 @@ def _apply_noise_and_ar_model(self):
self.randgen_motion = None
print("AR model and noise applied to precipitation cascades.")

def _initialize_velocity_perturbators(self):
def _initialize_velocity_perturbations(self):
"""
Initialize the velocity perturbators for each ensemble member if the velocity
perturbation method is specified.
"""
if self.vel_pert_method is not None:
init_vel_noise, generate_vel_noise = noise.get_method(self.vel_pert_method)

self.velocity_perturbators = []
self.velocity_perturbations = []
for j in range(self.n_ens_members):
kwargs = {
"randstate": self.randgen_motion[j],
Expand All @@ -575,11 +575,11 @@ def _initialize_velocity_perturbators(self):
vp = init_vel_noise(
self.velocity, 1.0 / self.kmperpixel, self.timestep, **kwargs
)
self.velocity_perturbators.append(
self.velocity_perturbations.append(
lambda t, vp=vp: generate_vel_noise(vp, t * self.timestep)
)
else:
self.velocity_perturbators = None
self.velocity_perturbations = None
print("Velocity perturbations initialized successfully.")

def _initialize_precipitation_mask(self):
Expand Down Expand Up @@ -693,9 +693,190 @@ def _initialize_params(self, precip):
"war": self.war,
}

def _update_state(self):
# TODO
pass
def _update_state(self, state, params):
"""
Update the state during the nowcasting loop. This function handles the AR model iteration,
noise generation, recomposition, and mask application for each ensemble member.
"""
precip_forecast_out = [None] * params["n_ens_members"]

# Update the deterministic AR(p) model if noise or sprog mask is used
if params["noise_method"] is None or params["mask_method"] == "sprog":
self._update_deterministic_ar_model(state, params)

# Worker function for each ensemble member
def worker(j):
self._apply_ar_model_to_cascades(j, state, params)
precip_forecast_out[j] = self._recompose_and_apply_mask(j, state, params)

# Use Dask for parallel execution if available
if (
DASK_IMPORTED
and params["n_ens_members"] > 1
and params["num_ensemble_workers"] > 1
):
res = []
for j in range(params["n_ens_members"]):
res.append(dask.delayed(worker)(j))
dask.compute(*res, num_workers=params["num_ensemble_workers"])
else:
for j in range(params["n_ens_members"]):
worker(j)

return np.stack(precip_forecast_out), state

def _update_deterministic_ar_model(self, state, params):
"""
Update the deterministic AR(p) model for each cascade level if noise is disabled
or if the sprog mask is used.
"""
for i in range(params["n_cascade_levels"]):
state["precip_m"][i] = autoregression.iterate_ar_model(
state["precip_m"][i], params["phi"][i, :]
)

state["precip_m_d"]["cascade_levels"] = [
state["precip_m"][i][-1] for i in range(params["n_cascade_levels"])
]

if params["domain"] == "spatial":
state["precip_m_d"]["cascade_levels"] = np.stack(
state["precip_m_d"]["cascade_levels"]
)

precip_m_ = params["recomp_method"](state["precip_m_d"])

if params["domain"] == "spectral":
precip_m_ = params["fft"].irfft2(precip_m_)

if params["mask_method"] == "sprog":
state["mask_prec"] = compute_percentile_mask(precip_m_, params["war"])

def _apply_ar_model_to_cascades(self, j, state, params):
"""
Apply the AR(p) model to the cascades for each ensemble member, including
noise generation and normalization.
"""
# Generate noise if enabled
if params["noise_method"] is not None:
eps = self._generate_and_decompose_noise(j, state, params)
else:
eps = None

# Iterate the AR(p) model for each cascade level
for i in range(params["n_cascade_levels"]):
if eps is not None:
eps_ = eps["cascade_levels"][i]
eps_ *= params["noise_std_coeffs"][i]
else:
eps_ = None

# Apply the AR(p) model with or without perturbations
if eps is not None or params["vel_pert_method"] is not None:
state["precip_cascades"][j][i] = autoregression.iterate_ar_model(
state["precip_cascades"][j][i], params["phi"][i, :], eps=eps_
)
else:
# use the deterministic AR(p) model computed above if
# perturbations are disabled
state["precip_cascades"][j][i] = state["precip_m"][i]

eps = None
eps_ = None

def _generate_and_decompose_noise(self, j, state, params):
"""
Generate and decompose the noise field into cascades for a given ensemble member.
"""
eps = params["generate_noise"](
params["pert_gen"],
randstate=state["randgen_prec"][j],
fft_method=state["fft_objs"][j],
domain=params["domain"],
)

eps = params["decomp_method"](
eps,
params["filter"],
fft_method=state["fft_objs"][j],
input_domain=params["domain"],
output_domain=params["domain"],
compute_stats=True,
normalize=True,
compact_output=True,
)

return eps

def _recompose_and_apply_mask(self, j, state, params):
"""
Recompose the precipitation field from cascades and apply the precipitation mask.
"""
state["precip_decomp"][j]["cascade_levels"] = [
state["precip_cascades"][j][i][-1, :]
for i in range(params["n_cascade_levels"])
]

if params["domain"] == "spatial":
state["precip_decomp"][j]["cascade_levels"] = np.stack(
state["precip_decomp"][j]["cascade_levels"]
)

precip_forecast = params["recomp_method"](state["precip_decomp"][j])

if params["domain"] == "spectral":
precip_forecast = state["fft_objs"][j].irfft2(precip_forecast)

# Apply the precipitation mask
if params["mask_method"] is not None:
precip_forecast = self._apply_precipitation_mask(
precip_forecast, j, state, params
)

# Adjust the CDF of the forecast to match the observed precipitation field
if params["probmatching_method"] == "cdf":
precip_forecast = probmatching.nonparam_match_empirical_cdf(
precip_forecast, params["precip"]
)
# Adjust the mean of the forecast to match the observed mean
elif params["probmatching_method"] == "mean":
mask = precip_forecast >= params["precip_thr"]
mu_fct = np.mean(precip_forecast[mask])
precip_forecast[mask] = precip_forecast[mask] - mu_fct + params["mu_0"]

# Update the mask for incremental method
if params["mask_method"] == "incremental":
state["mask_prec"][j] = nowcast_utils.compute_dilated_mask(
precip_forecast >= params["precip_thr"],
params["struct"],
params["mask_rim"],
)

# Apply the domain mask (set masked areas to NaN)
precip_forecast[params["domain_mask"]] = np.nan

return precip_forecast

def _apply_precipitation_mask(self, precip_forecast, j, state, params):
"""
Apply the precipitation mask to prevent new precipitation from generating
in areas where it was not observed.
"""
precip_forecast_min = precip_forecast.min()

if params["mask_method"] == "incremental":
precip_forecast = (
precip_forecast_min
+ (precip_forecast - precip_forecast_min) * state["mask_prec"][j]
)
mask_prec_ = precip_forecast > precip_forecast_min
else:
mask_prec_ = state["mask_prec"]

# Set to min value outside the mask
precip_forecast[~mask_prec_] = precip_forecast_min

return precip_forecast

def _measure_time(self, label, start_time):
"""
Expand Down Expand Up @@ -975,144 +1156,3 @@ def forecast(

# Call the appropriate methods within the class
return nowcaster.compute_forecast()


def _update(state, params):
precip_forecast_out = [None] * params["n_ens_members"]

if params["noise_method"] is None or params["mask_method"] == "sprog":
for i in range(params["n_cascade_levels"]):
# use a separate AR(p) model for the non-perturbed forecast,
# from which the mask is obtained
state["precip_m"][i] = autoregression.iterate_ar_model(
state["precip_m"][i], params["phi"][i, :]
)

state["precip_m_d"]["cascade_levels"] = [
state["precip_m"][i][-1] for i in range(params["n_cascade_levels"])
]
if params["domain"] == "spatial":
state["precip_m_d"]["cascade_levels"] = np.stack(
state["precip_m_d"]["cascade_levels"]
)
precip_m_ = params["recomp_method"](state["precip_m_d"])
if params["domain"] == "spectral":
precip_m_ = params["fft"].irfft2(precip_m_)

if params["mask_method"] == "sprog":
state["mask_prec"] = compute_percentile_mask(precip_m_, params["war"])

def worker(j):
if params["noise_method"] is not None:
# generate noise field
eps = params["generate_noise"](
params["pert_gen"],
randstate=state["randgen_prec"][j],
fft_method=state["fft_objs"][j],
domain=params["domain"],
)

# decompose the noise field into a cascade
eps = params["decomp_method"](
eps,
params["filter"],
fft_method=state["fft_objs"][j],
input_domain=params["domain"],
output_domain=params["domain"],
compute_stats=True,
normalize=True,
compact_output=True,
)
else:
eps = None

# iterate the AR(p) model for each cascade level
for i in range(params["n_cascade_levels"]):
# normalize the noise cascade
if eps is not None:
eps_ = eps["cascade_levels"][i]
eps_ *= params["noise_std_coeffs"][i]
else:
eps_ = None
# apply AR(p) process to cascade level
if eps is not None or params["vel_pert_method"] is not None:
state["precip_cascades"][j][i] = autoregression.iterate_ar_model(
state["precip_cascades"][j][i], params["phi"][i, :], eps=eps_
)
else:
# use the deterministic AR(p) model computed above if
# perturbations are disabled
state["precip_cascades"][j][i] = state["precip_m"][i]

eps = None
eps_ = None

# compute the recomposed precipitation field(s) from the cascades
# obtained from the AR(p) model(s)
state["precip_decomp"][j]["cascade_levels"] = [
state["precip_cascades"][j][i][-1, :]
for i in range(params["n_cascade_levels"])
]
if params["domain"] == "spatial":
state["precip_decomp"][j]["cascade_levels"] = np.stack(
state["precip_decomp"][j]["cascade_levels"]
)

precip_forecast = params["recomp_method"](state["precip_decomp"][j])

if params["domain"] == "spectral":
precip_forecast = state["fft_objs"][j].irfft2(precip_forecast)

if params["mask_method"] is not None:
# apply the precipitation mask to prevent generation of new
# precipitation into areas where it was not originally
# observed
precip_forecast_min = precip_forecast.min()
if params["mask_method"] == "incremental":
precip_forecast = (
precip_forecast_min
+ (precip_forecast - precip_forecast_min) * state["mask_prec"][j]
)
mask_prec_ = precip_forecast > precip_forecast_min
else:
mask_prec_ = state["mask_prec"]

# set to min value outside mask
precip_forecast[~mask_prec_] = precip_forecast_min

if params["probmatching_method"] == "cdf":
# adjust the CDF of the forecast to match the most recently
# observed precipitation field
precip_forecast = probmatching.nonparam_match_empirical_cdf(
precip_forecast, params["precip"]
)
elif params["probmatching_method"] == "mean":
mask = precip_forecast >= params["precip_thr"]
mu_fct = np.mean(precip_forecast[mask])
precip_forecast[mask] = precip_forecast[mask] - mu_fct + params["mu_0"]

if params["mask_method"] == "incremental":
state["mask_prec"][j] = nowcast_utils.compute_dilated_mask(
precip_forecast >= params["precip_thr"],
params["struct"],
params["mask_rim"],
)

precip_forecast[params["domain_mask"]] = np.nan

precip_forecast_out[j] = precip_forecast

if (
DASK_IMPORTED
and params["n_ens_members"] > 1
and params["num_ensemble_workers"] > 1
):
res = []
for j in range(params["n_ens_members"]):
res.append(dask.delayed(worker)(j))
dask.compute(*res, num_workers=params["num_ensemble_workers"])
else:
for j in range(params["n_ens_members"]):
worker(j)

return np.stack(precip_forecast_out), state

0 comments on commit 46bc44a

Please sign in to comment.