Skip to content

Commit

Permalink
Name changes from feedback Ruben
Browse files Browse the repository at this point in the history
  • Loading branch information
sidekock committed Oct 22, 2024
1 parent a1ce4bc commit b002354
Showing 1 changed file with 64 additions and 56 deletions.
120 changes: 64 additions & 56 deletions pysteps/nowcasts/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,35 +41,35 @@ def __init__(self, precip, velocity, timesteps, **kwargs):
self.timesteps = timesteps
self.n_ens_members = kwargs.get("n_ens_members", 24)
self.n_cascade_levels = kwargs.get("n_cascade_levels", 6)
self.precip_thr = kwargs.get("precip_thr", None)
self.precip_threshold = kwargs.get("precip_thr", None)
self.kmperpixel = kwargs.get("kmperpixel", None)
self.timestep = kwargs.get("timestep", None)
self.extrap_method = kwargs.get("extrap_method", "semilagrangian")
self.decomp_method = kwargs.get("decomp_method", "fft")
self.extrapolation_method = kwargs.get("extrap_method", "semilagrangian")
self.decomposition_method = kwargs.get("decomp_method", "fft")
self.bandpass_filter_method = kwargs.get("bandpass_filter_method", "gaussian")
self.noise_method = kwargs.get("noise_method", "nonparametric")
self.noise_stddev_adj = kwargs.get("noise_stddev_adj", None)
self.ar_order = kwargs.get("ar_order", 2)
self.vel_pert_method = kwargs.get("vel_pert_method", "bps")
self.velocity_perturbation_method = kwargs.get("vel_pert_method", "bps")
self.conditional = kwargs.get("conditional", False)
self.probmatching_method = kwargs.get("probmatching_method", "cdf")
self.mask_method = kwargs.get("mask_method", "incremental")
self.seed = kwargs.get("seed", None)
self.num_workers = kwargs.get("num_workers", 1)
self.fft_method = kwargs.get("fft_method", "numpy")
self.domain = kwargs.get("domain", "spatial")
self.extrap_kwargs = kwargs.get("extrap_kwargs", None)
self.extrapolation_kwargs = kwargs.get("extrap_kwargs", None)
self.filter_kwargs = kwargs.get("filter_kwargs", None)
self.noise_kwargs = kwargs.get("noise_kwargs", None)
self.vel_pert_kwargs = kwargs.get("vel_pert_kwargs", None)
self.velocity_pertubation_kwargs = kwargs.get("vel_pert_kwargs", None)
self.mask_kwargs = kwargs.get("mask_kwargs", None)
self.measure_time = kwargs.get("measure_time", False)
self.callback = kwargs.get("callback", None)
self.return_output = kwargs.get("return_output", True)

# Additional variables for internal state management
self.fft = None
self.bp_filter = None
self.bandpass_filter = None
self.extrapolator_method = None
self.domain_mask = None
self.precip_cascades = None
Expand All @@ -84,8 +84,8 @@ def __init__(self, precip, velocity, timesteps, **kwargs):
self.mask_prec = None
self.mask_thr = None
self.precip_decomp = None
self.vp_par = None
self.vp_perp = None
self.velocity_pertubation_parallel = None
self.velocity_pertubation_perp = None
self.fft_objs = None
self.generate_noise = None

Expand Down Expand Up @@ -121,7 +121,7 @@ def compute_forecast(self):
if self.measure_time:
self._measure_time("Initialization", self.start_time_init)

# RUn the main nowcast loop
# Run the main nowcast loop
self._nowcast_main()

if self.measure_time:
Expand Down Expand Up @@ -159,9 +159,9 @@ def _nowcast_main(self):
self.velocity,
state,
self.timesteps,
self.extrap_method,
self.extrapolation_method,
self._update_state, # Reference to the update function
extrap_kwargs=self.extrap_kwargs,
extrap_kwargs=self.extrapolation_kwargs,
velocity_pert_gen=self.velocity_perturbations,
params=params,
ensemble=True,
Expand Down Expand Up @@ -203,7 +203,7 @@ def _check_inputs(self):
f"Unknown mask method '{self.mask_method}'. "
"Must be 'obs', 'sprog', 'incremental', or None."
)
if self.precip_thr is None:
if self.precip_threshold is None:
if self.conditional:
raise ValueError("conditional=True but precip_thr is not specified.")
if self.mask_method is not None:
Expand All @@ -222,25 +222,25 @@ def _check_inputs(self):
"Must be 'auto', 'fixed', or None."
)
if self.kmperpixel is None:
if self.vel_pert_method is not None:
if self.velocity_perturbation_method is not None:
raise ValueError("vel_pert_method is set but kmperpixel=None")
if self.mask_method == "incremental":
raise ValueError("mask_method='incremental' but kmperpixel=None")
if self.timestep is None:
if self.vel_pert_method is not None:
if self.velocity_perturbation_method is not None:
raise ValueError("vel_pert_method is set but timestep=None")
if self.mask_method == "incremental":
raise ValueError("mask_method='incremental' but timestep=None")

# Handle None values for various kwargs
if self.extrap_kwargs is None:
self.extrap_kwargs = {}
if self.extrapolation_kwargs is None:
self.extrapolation_kwargs = {}
if self.filter_kwargs is None:
self.filter_kwargs = {}
if self.noise_kwargs is None:
self.noise_kwargs = {}
if self.vel_pert_kwargs is None:
self.vel_pert_kwargs = {}
if self.velocity_pertubation_kwargs is None:
self.velocity_pertubation_kwargs = {}
if self.mask_kwargs is None:
self.mask_kwargs = {}

Expand All @@ -265,16 +265,16 @@ def _print_forecast_info(self):

print("Methods")
print("-------")
print(f"extrapolation: {self.extrap_method}")
print(f"extrapolation: {self.extrapolation_method}")
print(f"bandpass filter: {self.bandpass_filter_method}")
print(f"decomposition: {self.decomp_method}")
print(f"decomposition: {self.decomposition_method}")
print(f"noise generator: {self.noise_method}")
print(
"noise adjustment: {}".format(
("yes" if self.noise_stddev_adj else "no")
)
)
print(f"velocity perturbator: {self.vel_pert_method}")
print(f"velocity perturbator: {self.velocity_perturbation_method}")
print(
"conditional statistics: {}".format(("yes" if self.conditional else "no"))
)
Expand All @@ -295,22 +295,22 @@ def _print_forecast_info(self):
print(f"number of cascade levels: {self.n_cascade_levels}")
print(f"order of the AR(p) model: {self.ar_order}")

if self.vel_pert_method == "bps":
self.vp_par = self.vel_pert_kwargs.get(
if self.velocity_perturbation_method == "bps":
self.velocity_pertubation_parallel = self.velocity_pertubation_kwargs.get(
"p_par", noise.motion.get_default_params_bps_par()
)
self.vp_perp = self.vel_pert_kwargs.get(
self.velocity_pertubation_perp = self.velocity_pertubation_kwargs.get(
"p_perp", noise.motion.get_default_params_bps_perp()
)
print(
f"velocity perturbations, parallel: {self.vp_par[0]},{self.vp_par[1]},{self.vp_par[2]}"
f"velocity perturbations, parallel: {self.velocity_pertubation_parallel[0]},{self.velocity_pertubation_parallel[1]},{self.velocity_pertubation_parallel[2]}"
)
print(
f"velocity perturbations, perpendicular: {self.vp_perp[0]},{self.vp_perp[1]},{self.vp_perp[2]}"
f"velocity perturbations, perpendicular: {self.velocity_pertubation_perp[0]},{self.velocity_pertubation_perp[1]},{self.velocity_pertubation_perp[2]}"
)

if self.precip_thr is not None:
print(f"precip. intensity threshold: {self.precip_thr}")
if self.precip_threshold is not None:
print(f"precip. intensity threshold: {self.precip_threshold}")

def _initialize_nowcast_components(self):
"""
Expand All @@ -325,15 +325,17 @@ def _initialize_nowcast_components(self):

# Initialize the band-pass filter for the cascade decomposition
filter_method = cascade.get_method(self.bandpass_filter_method)
self.bp_filter = filter_method(
self.bandpass_filter = filter_method(
(M, N), self.n_cascade_levels, **(self.filter_kwargs or {})
)

# Get the decomposition method (e.g., FFT)
self.decomp_method, self.recomp_method = cascade.get_method(self.decomp_method)
self.decomposition_method, self.recomp_method = cascade.get_method(
self.decomposition_method
)

# Get the extrapolation method (e.g., semilagrangian)
self.extrapolator_method = extrapolation.get_method(self.extrap_method)
self.extrapolator_method = extrapolation.get_method(self.extrapolation_method)

# Generate the mesh grid for spatial coordinates
x_values, y_values = np.meshgrid(np.arange(N), np.arange(M))
Expand All @@ -355,14 +357,14 @@ def _perform_extrapolation(self):
if self.conditional:
self.mask_thr = np.logical_and.reduce(
[
self.precip[i, :, :] >= self.precip_thr
self.precip[i, :, :] >= self.precip_threshold
for i in range(self.precip.shape[0])
]
)
else:
self.mask_thr = None

extrap_kwargs = self.extrap_kwargs.copy()
extrap_kwargs = self.extrapolation_kwargs.copy()
extrap_kwargs["xy_coords"] = self.xy_coords
extrap_kwargs["allow_nonfinite_values"] = (
True if np.any(~np.isfinite(self.precip)) else False
Expand Down Expand Up @@ -434,10 +436,10 @@ def _apply_noise_and_ar_model(self):
# Compute noise adjustment coefficients
self.noise_std_coeffs = noise.utils.compute_noise_stddev_adjs(
self.precip[-1, :, :],
self.precip_thr,
self.precip_threshold,
np.min(self.precip),
self.bp_filter,
self.decomp_method,
self.bandpass_filter,
self.decomposition_method,
self.pert_gen,
self.generate_noise,
20,
Expand Down Expand Up @@ -479,9 +481,9 @@ def _apply_noise_and_ar_model(self):
# Decompose the input precipitation fields
self.precip_decomp = []
for i in range(self.ar_order + 1):
precip_ = self.decomp_method(
precip_ = self.decomposition_method(
self.precip[i, :, :],
self.bp_filter,
self.bandpass_filter,
mask=self.mask_thr,
fft_method=self.fft,
output_domain=self.domain,
Expand Down Expand Up @@ -560,15 +562,21 @@ def _initialize_velocity_perturbations(self):
Initialize the velocity perturbators for each ensemble member if the velocity
perturbation method is specified.
"""
if self.vel_pert_method is not None:
init_vel_noise, generate_vel_noise = noise.get_method(self.vel_pert_method)
if self.velocity_perturbation_method is not None:
init_vel_noise, generate_vel_noise = noise.get_method(
self.velocity_perturbation_method
)

self.velocity_perturbations = []
for j in range(self.n_ens_members):
kwargs = {
"randstate": self.randgen_motion[j],
"p_par": self.vel_pert_kwargs.get("p_par", self.vp_par),
"p_perp": self.vel_pert_kwargs.get("p_perp", self.vp_perp),
"p_par": self.velocity_pertubation_kwargs.get(
"p_par", self.velocity_pertubation_parallel
),
"p_perp": self.velocity_pertubation_kwargs.get(
"p_perp", self.velocity_pertubation_perp
),
}
vp = init_vel_noise(
self.velocity, 1.0 / self.kmperpixel, self.timestep, **kwargs
Expand All @@ -588,30 +596,30 @@ def _initialize_precipitation_mask(self):

if self.probmatching_method == "mean":
self.mu_0 = np.mean(
self.precip[-1, :, :][self.precip[-1, :, :] >= self.precip_thr]
self.precip[-1, :, :][self.precip[-1, :, :] >= self.precip_threshold]
)
else:
self.mu_0 = None

self.precip_m = None
self.precip_m_d = None
self.precip_mask = None
self.precip_mask_decomposed = None
self.war = None
self.struct = None
self.mask_rim = None

if self.mask_method is not None:
self.mask_prec = self.precip[-1, :, :] >= self.precip_thr
self.mask_prec = self.precip[-1, :, :] >= self.precip_threshold

if self.mask_method == "sprog":
# Compute the wet area ratio and the precipitation mask
self.war = np.sum(self.mask_prec) / (
self.precip.shape[1] * self.precip.shape[2]
)
self.precip_m = [
self.precip_mask = [
self.precip_cascades[0][i].copy()
for i in range(self.n_cascade_levels)
]
self.precip_m_d = self.precip_decomp[0].copy()
self.precip_mask_decomposed = self.precip_decomp[0].copy()

elif self.mask_method == "incremental":
# Get mask parameters
Expand All @@ -632,8 +640,8 @@ def _initialize_precipitation_mask(self):
else:
self.mask_prec = None

if self.noise_method is None and self.precip_m is None:
self.precip_m = [
if self.noise_method is None and self.precip_mask is None:
self.precip_mask = [
self.precip_cascades[0][i].copy() for i in range(self.n_cascade_levels)
]
print("Precipitation mask initialized successfully.")
Expand All @@ -657,8 +665,8 @@ def _initialize_state(self):
"mask_prec": self.mask_prec,
"precip_cascades": self.precip_cascades,
"precip_decomp": self.precip_decomp,
"precip_m": self.precip_m,
"precip_m_d": self.precip_m_d,
"precip_m": self.precip_mask,
"precip_m_d": self.precip_mask_decomposed,
"randgen_prec": self.randgen_prec,
}

Expand All @@ -667,10 +675,10 @@ def _initialize_params(self, precip):
Initialize the params dictionary used during the nowcast iteration.
"""
return {
"decomp_method": self.decomp_method,
"decomp_method": self.decomposition_method,
"domain": self.domain,
"domain_mask": self.domain_mask,
"filter": self.bp_filter,
"filter": self.bandpass_filter,
"fft": self.fft,
"generate_noise": self.generate_noise,
"mask_method": self.mask_method,
Expand All @@ -685,7 +693,7 @@ def _initialize_params(self, precip):
"pert_gen": self.pert_gen,
"probmatching_method": self.probmatching_method,
"precip": precip,
"precip_thr": self.precip_thr,
"precip_thr": self.precip_threshold,
"recomp_method": self.recomp_method,
"struct": self.struct,
"war": self.war,
Expand Down

0 comments on commit b002354

Please sign in to comment.