Skip to content

Commit

Permalink
Added config dataclass to steps nowcast, v4
Browse files Browse the repository at this point in the history
  • Loading branch information
sidekock committed Oct 31, 2024
1 parent 3da1696 commit 8c7982c
Showing 1 changed file with 36 additions and 23 deletions.
59 changes: 36 additions & 23 deletions pysteps/nowcasts/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,32 @@ class StepsNowcasterConfig:
class StepsNowcasterParams:
fft: Any = None
bandpass_filter: Any = None
extrapolator_method: Any = None
decomposition_method: Any = None
recomposition_method: Any = None
noise_generator: Optional[callable] = None
perturbation_generator: Optional[callable] = None
noise_std_coeffs: Optional[np.ndarray] = None
ar_model_coefficients: Optional[np.ndarray] = None
noise_std_coefficients: Optional[np.ndarray] = None
ar_model_coefficients: Optional[np.ndarray] = None # Corresponds to phi
autocorrelation_coefficients: Optional[np.ndarray] = None # Corresponds to gamma
domain_mask: Optional[np.ndarray] = None
structuring_element: Optional[np.ndarray] = None
precipitation_mean: Optional[float] = None
wet_area_ratio: Optional[float] = None
num_workers: int = 1
generate_noise: Optional[callable] = None
xy_coordinates: Optional[np.ndarray] = None
velocity_perturbation_parallel: Optional[List[float]] = None
velocity_perturbation_perpendicular: Optional[List[float]] = None
num_ensemble_workers: int = 1


@dataclass
class StepsNowcasterState:
precip_cascades: Optional[List[List[np.ndarray]]] = field(default_factory=list)
precip_decomposed: Optional[List[Dict[str, Any]]] = field(default_factory=list)
mask_precip: Optional[np.ndarray] = None
mask_threshold: Optional[np.ndarray] = None
random_generator_precip: Optional[List[np.random.RandomState]] = field(
default_factory=list
)
Expand All @@ -97,22 +105,19 @@ class StepsNowcasterState:
)
velocity_perturbations: Optional[List[callable]] = field(default_factory=list)
fft_objects: Optional[List[Any]] = field(default_factory=list)
precip_forecast: Optional[List[Any]] = field(default_factory=list)


class StepsNowcaster:
def __init__(self, precip, velocity, timesteps, steps_config):
def __init__(self, precip, velocity, time_steps, steps_config):
# Store inputs and optional parameters
self.precip = precip
self.velocity = velocity
self.timesteps = timesteps
self.tim_esteps = time_steps

# Store the config data:
self.config = steps_config

# Store the state and params data:
self.state = StepsNowcasterState()
self.params = StepsNowcasterParams()

# Additional variables for internal state management
self.fft = None
self.bandpass_filter = None
Expand All @@ -138,18 +143,21 @@ def __init__(self, precip, velocity, timesteps, steps_config):
self.recomp_method = None
self.xy_coords = None

self.mu_0 = None
self.precipitation_mean = None
# Initialize number of ensemble workers
self.num_ensemble_workers = min(
self.config.n_ens_members, self.config.num_workers
)

# Store the state and params data:
self.state = StepsNowcasterState()
self.params = StepsNowcasterParams()

# Additional variables for time measurement
self.start_time_init = None
self.init_time = None
self.mainloop_time = None

# Initialize number of ensemble workers
self.num_ensemble_workers = min(
self.config.n_ens_members, self.config.num_workers
)

def compute_forecast(self):
"""
Main loop for nowcast ensemble generation. This handles extrapolation,
Expand Down Expand Up @@ -214,7 +222,7 @@ def _nowcast_main(self):
precip,
self.velocity,
state,
self.timesteps,
self.tim_esteps,
self.config.extrapolation_method,
self._update_state, # Reference to the update function
extrap_kwargs=self.extrapolation_kwargs,
Expand Down Expand Up @@ -248,8 +256,8 @@ def _check_inputs(self):
f"shape(precip)={self.precip.shape}, shape(velocity)={self.velocity.shape}"
)
if (
isinstance(self.timesteps, list)
and not sorted(self.timesteps) == self.timesteps
isinstance(self.tim_esteps, list)
and not sorted(self.tim_esteps) == self.tim_esteps
):
raise ValueError("timesteps must be in ascending order")
if np.any(~np.isfinite(self.velocity)):
Expand Down Expand Up @@ -347,10 +355,10 @@ def _print_forecast_info(self):

print("Parameters")
print("----------")
if isinstance(self.timesteps, int):
print(f"number of time steps: {self.timesteps}")
if isinstance(self.tim_esteps, int):
print(f"number of time steps: {self.tim_esteps}")
else:
print(f"time steps: {self.timesteps}")
print(f"time steps: {self.tim_esteps}")
print(f"ensemble size: {self.config.n_ens_members}")
print(f"parallel threads: {self.config.num_workers}")
print(f"number of cascade levels: {self.config.n_cascade_levels}")
Expand Down Expand Up @@ -379,6 +387,11 @@ def _initialize_nowcast_components(self):
"""
Initialize the FFT, bandpass filters, decomposition methods, and extrapolation method.
"""
# Initialize number of ensemble workers
self.params.num_ensemble_workers = min(
self.config.n_ens_members, self.config.num_workers
)

M, N = self.precip.shape[1:] # Extract the spatial dimensions (height, width)

# Initialize FFT method
Expand Down Expand Up @@ -666,13 +679,13 @@ def _initialize_precipitation_mask(self):
self.precip_forecast = [[] for _ in range(self.config.n_ens_members)]

if self.config.probmatching_method == "mean":
self.mu_0 = np.mean(
self.precipitation_mean = np.mean(
self.precip[-1, :, :][
self.precip[-1, :, :] >= self.config.precip_threshold
]
)
else:
self.mu_0 = None
self.precipitation_mean = None

self.precip_mask = None
self.precip_mask_decomposed = None
Expand Down Expand Up @@ -759,7 +772,7 @@ def _initialize_params(self, precip):
"generate_noise": self.generate_noise,
"mask_method": self.config.mask_method,
"mask_rim": self.mask_rim,
"mu_0": self.mu_0,
"mu_0": self.precipitation_mean,
"n_cascade_levels": self.config.n_cascade_levels,
"n_ens_members": self.config.n_ens_members,
"noise_method": self.config.noise_method,
Expand Down

0 comments on commit 8c7982c

Please sign in to comment.