diff --git a/pysteps/nowcasts/steps.py b/pysteps/nowcasts/steps.py index 7cdd0d3f..abbc0100 100644 --- a/pysteps/nowcasts/steps.py +++ b/pysteps/nowcasts/steps.py @@ -71,17 +71,24 @@ class StepsNowcasterConfig: class StepsNowcasterParams: fft: Any = None bandpass_filter: Any = None + extrapolator_method: Any = None decomposition_method: Any = None recomposition_method: Any = None noise_generator: Optional[callable] = None perturbation_generator: Optional[callable] = None - noise_std_coeffs: Optional[np.ndarray] = None - ar_model_coefficients: Optional[np.ndarray] = None + noise_std_coefficients: Optional[np.ndarray] = None + ar_model_coefficients: Optional[np.ndarray] = None # Corresponds to phi + autocorrelation_coefficients: Optional[np.ndarray] = None # Corresponds to gamma domain_mask: Optional[np.ndarray] = None structuring_element: Optional[np.ndarray] = None precipitation_mean: Optional[float] = None wet_area_ratio: Optional[float] = None num_workers: int = 1 + generate_noise: Optional[callable] = None + xy_coordinates: Optional[np.ndarray] = None + velocity_perturbation_parallel: Optional[List[float]] = None + velocity_perturbation_perpendicular: Optional[List[float]] = None + num_ensemble_workers: int = 1 @dataclass @@ -89,6 +96,7 @@ class StepsNowcasterState: precip_cascades: Optional[List[List[np.ndarray]]] = field(default_factory=list) precip_decomposed: Optional[List[Dict[str, Any]]] = field(default_factory=list) mask_precip: Optional[np.ndarray] = None + mask_threshold: Optional[np.ndarray] = None random_generator_precip: Optional[List[np.random.RandomState]] = field( default_factory=list ) @@ -97,22 +105,19 @@ class StepsNowcasterState: ) velocity_perturbations: Optional[List[callable]] = field(default_factory=list) fft_objects: Optional[List[Any]] = field(default_factory=list) + precip_forecast: Optional[List[Any]] = field(default_factory=list) class StepsNowcaster: - def __init__(self, precip, velocity, timesteps, steps_config): + def __init__(self, precip, velocity, time_steps, steps_config): # Store inputs and optional parameters self.precip = precip self.velocity = velocity - self.timesteps = timesteps + self.tim_esteps = time_steps # Store the config data: self.config = steps_config - # Store the state and params data: - self.state = StepsNowcasterState() - self.params = StepsNowcasterParams() - # Additional variables for internal state management self.fft = None self.bandpass_filter = None @@ -138,18 +143,21 @@ def __init__(self, precip, velocity, timesteps, steps_config): self.recomp_method = None self.xy_coords = None - self.mu_0 = None + self.precipitation_mean = None + # Initialize number of ensemble workers + self.num_ensemble_workers = min( + self.config.n_ens_members, self.config.num_workers + ) + + # Store the state and params data: + self.state = StepsNowcasterState() + self.params = StepsNowcasterParams() # Additional variables for time measurement self.start_time_init = None self.init_time = None self.mainloop_time = None - # Initialize number of ensemble workers - self.num_ensemble_workers = min( - self.config.n_ens_members, self.config.num_workers - ) - def compute_forecast(self): """ Main loop for nowcast ensemble generation. This handles extrapolation, @@ -214,7 +222,7 @@ def _nowcast_main(self): precip, self.velocity, state, - self.timesteps, + self.tim_esteps, self.config.extrapolation_method, self._update_state, # Reference to the update function extrap_kwargs=self.extrapolation_kwargs, @@ -248,8 +256,8 @@ def _check_inputs(self): f"shape(precip)={self.precip.shape}, shape(velocity)={self.velocity.shape}" ) if ( - isinstance(self.timesteps, list) - and not sorted(self.timesteps) == self.timesteps + isinstance(self.tim_esteps, list) + and not sorted(self.tim_esteps) == self.tim_esteps ): raise ValueError("timesteps must be in ascending order") if np.any(~np.isfinite(self.velocity)): @@ -347,10 +355,10 @@ def _print_forecast_info(self): print("Parameters") print("----------") - if isinstance(self.timesteps, int): - print(f"number of time steps: {self.timesteps}") + if isinstance(self.tim_esteps, int): + print(f"number of time steps: {self.tim_esteps}") else: - print(f"time steps: {self.timesteps}") + print(f"time steps: {self.tim_esteps}") print(f"ensemble size: {self.config.n_ens_members}") print(f"parallel threads: {self.config.num_workers}") print(f"number of cascade levels: {self.config.n_cascade_levels}") @@ -379,6 +387,11 @@ def _initialize_nowcast_components(self): """ Initialize the FFT, bandpass filters, decomposition methods, and extrapolation method. """ + # Initialize number of ensemble workers + self.params.num_ensemble_workers = min( + self.config.n_ens_members, self.config.num_workers + ) + M, N = self.precip.shape[1:] # Extract the spatial dimensions (height, width) # Initialize FFT method @@ -666,13 +679,13 @@ def _initialize_precipitation_mask(self): self.precip_forecast = [[] for _ in range(self.config.n_ens_members)] if self.config.probmatching_method == "mean": - self.mu_0 = np.mean( + self.precipitation_mean = np.mean( self.precip[-1, :, :][ self.precip[-1, :, :] >= self.config.precip_threshold ] ) else: - self.mu_0 = None + self.precipitation_mean = None self.precip_mask = None self.precip_mask_decomposed = None @@ -759,7 +772,7 @@ def _initialize_params(self, precip): "generate_noise": self.generate_noise, "mask_method": self.config.mask_method, "mask_rim": self.mask_rim, - "mu_0": self.mu_0, + "mu_0": self.precipitation_mean, "n_cascade_levels": self.config.n_cascade_levels, "n_ens_members": self.config.n_ens_members, "noise_method": self.config.noise_method,