Skip to content

Commit

Permalink
Added config dataclass to steps nowcast, v3
Browse files Browse the repository at this point in the history
  • Loading branch information
sidekock committed Oct 31, 2024
1 parent fa9a1ef commit 3da1696
Showing 1 changed file with 45 additions and 6 deletions.
51 changes: 45 additions & 6 deletions pysteps/nowcasts/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from pysteps.timeseries import autoregression, correlation
from pysteps.nowcasts.utils import compute_percentile_mask, nowcast_main_loop

from dataclasses import field
from typing import Optional, Dict, Any, Callable
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Callable, List

try:
import dask
Expand All @@ -36,6 +36,7 @@
DASK_IMPORTED = False


@dataclass
class StepsNowcasterConfig:
n_ens_members: int = 24
n_cascade_levels: int = 6
Expand Down Expand Up @@ -66,14 +67,52 @@ class StepsNowcasterConfig:
return_output: bool = True


@dataclass
class StepsNowcasterParams:
fft: Any = None
bandpass_filter: Any = None
decomposition_method: Any = None
recomposition_method: Any = None
noise_generator: Optional[callable] = None
perturbation_generator: Optional[callable] = None
noise_std_coeffs: Optional[np.ndarray] = None
ar_model_coefficients: Optional[np.ndarray] = None
domain_mask: Optional[np.ndarray] = None
structuring_element: Optional[np.ndarray] = None
precipitation_mean: Optional[float] = None
wet_area_ratio: Optional[float] = None
num_workers: int = 1


@dataclass
class StepsNowcasterState:
precip_cascades: Optional[List[List[np.ndarray]]] = field(default_factory=list)
precip_decomposed: Optional[List[Dict[str, Any]]] = field(default_factory=list)
mask_precip: Optional[np.ndarray] = None
random_generator_precip: Optional[List[np.random.RandomState]] = field(
default_factory=list
)
random_generator_motion: Optional[List[np.random.RandomState]] = field(
default_factory=list
)
velocity_perturbations: Optional[List[callable]] = field(default_factory=list)
fft_objects: Optional[List[Any]] = field(default_factory=list)


class StepsNowcaster:
def __init__(self, precip, velocity, timesteps, steps_config, **kwargs):
self.config = steps_config
def __init__(self, precip, velocity, timesteps, steps_config):
# Store inputs and optional parameters
self.precip = precip
self.velocity = velocity
self.timesteps = timesteps

# Store the config data:
self.config = steps_config

# Store the state and params data:
self.state = StepsNowcasterState()
self.params = StepsNowcasterParams()

# Additional variables for internal state management
self.fft = None
self.bandpass_filter = None
Expand Down Expand Up @@ -1201,7 +1240,7 @@ def forecast(
nowcaster = StepsNowcaster(
precip, velocity, timesteps, steps_config=nowcaster_config
)
forecast = nowcaster.compute_forecast()
forecast_steps_nowcast = nowcaster.compute_forecast()
nowcaster.reset_states()
# Call the appropriate methods within the class
return forecast
return forecast_steps_nowcast

0 comments on commit 3da1696

Please sign in to comment.