Merge pull request #10074 from gem/no-circular
ARISTOTLE: fix circular imports
micheles authored Oct 22, 2024
2 parents c414a1a + f889d9a commit 083e1be
Showing 10 changed files with 121 additions and 115 deletions.
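The cycle being broken here: openquake.calculators.event_based imported get_close_mosaic_models from openquake.engine.aristotle and expose_outputs from openquake.engine.engine, while openquake.engine.engine imports openquake.calculators. The fix moves both helpers down the layering, expose_outputs into openquake/calculators/base.py and get_close_mosaic_models into openquake/commonlib/calc.py, so all imports point one way. The following minimal, self-contained sketch reproduces the failure mode; engine_mod and calc_mod are illustrative stand-ins, not the actual OpenQuake modules:

import pathlib
import subprocess
import sys
import tempfile
import textwrap

tmp = pathlib.Path(tempfile.mkdtemp())
# a high-level module importing a helper from the low-level one...
(tmp / 'engine_mod.py').write_text(textwrap.dedent('''\
    from calc_mod import run_calc          # engine -> calculators
    def expose_outputs(dstore):
        return 'exposed ' + dstore
'''))
# ...while the low-level module imports the high-level one back
(tmp / 'calc_mod.py').write_text(textwrap.dedent('''\
    from engine_mod import expose_outputs  # calculators -> engine: cycle
    def run_calc(dstore):
        return expose_outputs(dstore)
'''))
# importing either module now fails with "partially initialized module"
res = subprocess.run([sys.executable, '-c', 'import calc_mod'],
                     cwd=tmp, capture_output=True, text=True)
print(res.stderr.strip().splitlines()[-1])

Moving expose_outputs into the lower layer, as this commit does with openquake.calculators.base, removes the back-edge: engine keeps importing calculators, but nothing in calculators needs engine at import time (the new assertion in independence_test.py enforces exactly that).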
70 changes: 70 additions & 0 deletions openquake/calculators/base.py
@@ -26,6 +26,7 @@
import logging
import operator
import traceback
import getpass
from datetime import datetime
from shapely import wkt
import psutil
@@ -71,6 +72,9 @@
                        ('min', F32), ('max', F32),
                        ('len', U16)])

USER = getpass.getuser()
MB = 1024 ** 2


def get_aelo_changelog():
    dic = collections.defaultdict(list)
@@ -1544,3 +1548,69 @@ def run_calc(job_ini, **kw):
    calc = calculators(log.get_oqparam(), log.calc_id)
    calc.run()
    return calc


def expose_outputs(dstore, owner=USER, status='complete'):
    """
    Build a correspondence between the outputs in the datastore and the
    ones in the database.

    :param dstore: datastore
    :param owner: user name of the job owner (default: the current user)
    :param status: job status to record in the database (default 'complete')
    """
    oq = dstore['oqparam']
    exportable = set(ekey[0] for ekey in exp)
    calcmode = oq.calculation_mode
    dskeys = set(dstore) & exportable  # exportable datastore keys
    dskeys.add('fullreport')
    if 'avg_gmf' in dskeys:
        dskeys.remove('avg_gmf')  # hide
    rlzs = dstore['full_lt'].rlzs
    if len(rlzs) > 1:
        dskeys.add('realizations')
    hdf5 = dstore.hdf5
    if 'hcurves-stats' in hdf5 or 'hcurves-rlzs' in hdf5:
        if oq.hazard_stats() or oq.individual_rlzs or len(rlzs) == 1:
            dskeys.add('hcurves')
        if oq.uniform_hazard_spectra:
            dskeys.add('uhs')  # export them
        if oq.hazard_maps:
            dskeys.add('hmaps')  # export them
    if len(rlzs) > 1 and not oq.collect_rlzs:
        if 'aggrisk' in dstore:
            dskeys.add('aggrisk-stats')
        if 'aggcurves' in dstore:
            dskeys.add('aggcurves-stats')
        if not oq.individual_rlzs:
            for out in ['avg_losses-rlzs', 'aggrisk', 'aggcurves']:
                if out in dskeys:
                    dskeys.remove(out)
    if 'curves-rlzs' in dstore and len(rlzs) == 1:
        dskeys.add('loss_curves-rlzs')
    if 'curves-stats' in dstore and len(rlzs) > 1:
        dskeys.add('loss_curves-stats')
    if oq.conditional_loss_poes:  # expose loss_maps outputs
        if 'loss_curves-stats' in dstore:
            dskeys.add('loss_maps-stats')
    if 'ruptures' in dskeys:
        if 'scenario' in calcmode or len(dstore['ruptures']) == 0:
            # do not export, as requested by Vitor
            exportable.remove('ruptures')
        else:
            dskeys.add('event_based_mfd')
    if 'hmaps' in dskeys and not oq.hazard_maps:
        dskeys.remove('hmaps')  # do not export the hazard maps
    if logs.dbcmd('get_job', dstore.calc_id) is None:
        # the calculation has not been imported in the db yet
        logs.dbcmd('import_job', dstore.calc_id, oq.calculation_mode,
                   oq.description + ' [parent]', owner, status,
                   oq.hazard_calculation_id, dstore.datadir)
    keysize = []
    for key in sorted(dskeys & exportable):
        try:
            size_mb = dstore.getsize(key) / MB
        except (KeyError, AttributeError):
            size_mb = -1
        if size_mb:
            keysize.append((key, size_mb))
    ds_size = dstore.getsize() / MB
    logs.dbcmd('create_outputs', dstore.calc_id, keysize, ds_size)
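After the move, both the engine and the calculators reach the helper through openquake.calculators.base. A hedged usage sketch (the 'job.ini' path is a placeholder):

from openquake.calculators.base import run_calc, expose_outputs

calc = run_calc('job.ini')      # run a calculation from a job file
expose_outputs(calc.datastore)  # register its outputs in the engine database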
7 changes: 3 additions & 4 deletions openquake/calculators/event_based.py
@@ -28,7 +28,6 @@
from openquake.baselib import config, hdf5, parallel, python3compat
from openquake.baselib.general import (
    AccumDict, humansize, groupby, block_splitter)
from openquake.engine.aristotle import get_close_mosaic_models
from openquake.hazardlib.geo.packager import fiona
from openquake.hazardlib.map_array import MapArray, get_mean_curve
from openquake.hazardlib.stats import geom_avg_std, compute_stats
@@ -45,14 +44,14 @@
from openquake.commonlib import util, logs, readinput, datastore
from openquake.commonlib.calc import (
    gmvs_to_poes, make_hmaps, slice_dt, build_slice_by_event, RuptureImporter,
    SLICE_BY_EVENT_NSITES)
    SLICE_BY_EVENT_NSITES, get_close_mosaic_models)
from openquake.risklib.riskinput import str2rsi, rsi2str
from openquake.calculators import base, views
from openquake.calculators.getters import get_rupture_getters, sig_eps_dt
from openquake.calculators.classical import ClassicalCalculator
from openquake.calculators.extract import Extractor
from openquake.calculators.postproc.plots import plot_avg_gmf
from openquake.engine import engine
from openquake.calculators.base import expose_outputs
from PIL import Image

U8 = numpy.uint8
@@ -809,7 +808,7 @@ def post_execute(self, dummy):
            # source model, however usually this is quite fast and
            # does not dominate the computation
            self.cl.run()
            engine.expose_outputs(self.cl.datastore)
            expose_outputs(self.cl.datastore)
            all = slice(None)
            for imt in oq.imtls:
                cl_mean_curves = get_mean_curve(
4 changes: 2 additions & 2 deletions openquake/calculators/post_risk.py
@@ -26,8 +26,8 @@
from openquake.baselib import general, parallel, python3compat
from openquake.commonlib import datastore, logs
from openquake.risklib import asset, scientific, reinsurance
from openquake.engine import engine
from openquake.calculators import base, views
from openquake.calculators.base import expose_outputs

U8 = numpy.uint8
F32 = numpy.float32
@@ -659,4 +659,4 @@ def post_aggregate(calc_id: int, aggregate_by):
        parallel.Starmap.init()
        prc = PostRiskCalculator(oqp, log.calc_id)
        prc.run(aggregate_by=[aggby])
        engine.expose_outputs(prc.datastore)
        expose_outputs(prc.datastore)
4 changes: 2 additions & 2 deletions openquake/commands/importcalc.py
@@ -20,7 +20,7 @@
import logging
from openquake.commonlib import logs, datastore
from openquake.calculators.extract import WebExtractor
from openquake.engine import engine
from openquake.calculators.base import expose_outputs


def main(calc_id):
@@ -51,7 +51,7 @@ def main(calc_id):
        webex.close()
    with datastore.read(calc_id) as dstore:
        pprint.pprint(dstore.get_attrs('/'))
        engine.expose_outputs(dstore, status='complete')
        expose_outputs(dstore, status='complete')
    logging.info('Imported calculation %s successfully', calc_id)


1 change: 1 addition & 0 deletions openquake/commands/tests/independence_test.py
@@ -50,6 +50,7 @@ def test_commonlib(self):

    def test_engine(self):
        assert_independent('openquake.engine', 'openquake.server')
        assert_independent('openquake.calculators', 'openquake.engine')


class CaseConsistencyTestCase(unittest.TestCase):
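The added assertion pins down the fix: openquake.calculators must not import openquake.engine anymore. As a sketch only (the engine's actual assert_independent helper may be implemented differently), such a check can probe the import graph in a clean interpreter:

import subprocess
import sys

def check_independent(pkg, forbidden):
    # import every submodule of `pkg` in a fresh interpreter and fail
    # if anything from `forbidden` ends up in sys.modules
    probe = '\n'.join([
        'import importlib, pkgutil, sys',
        f'pkg = importlib.import_module("{pkg}")',
        f'for info in pkgutil.walk_packages(pkg.__path__, "{pkg}."):',
        '    importlib.import_module(info.name)',
        f'bad = [m for m in sys.modules if m.startswith("{forbidden}")]',
        'raise SystemExit(1 if bad else 0)'])
    res = subprocess.run([sys.executable, '-c', probe])
    assert res.returncode == 0, f'{pkg} (transitively) imports {forbidden}'

check_independent('openquake.calculators', 'openquake.engine')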
38 changes: 34 additions & 4 deletions openquake/commonlib/calc.py
@@ -20,12 +20,13 @@
import operator
import functools
import numpy
from shapely.geometry import Point

from openquake.baselib import performance, parallel, hdf5, general
from openquake.hazardlib.source import rupture
from openquake.hazardlib import map_array
from openquake.hazardlib import map_array, geo
from openquake.hazardlib.source.rupture import get_events
from openquake.commonlib import util
from openquake.commonlib import util, readinput

TWO16 = 2 ** 16
TWO24 = 2 ** 24
@@ -49,6 +50,7 @@

# ############## utilities for the classical calculator ############### #


# used only in the view global_hcurves
def convert_to_array(pmap, nsites, imtls, inner_idx=0):
"""
@@ -218,7 +220,7 @@ def import_rups_events(self, rup_array, get_rupture_getters):
            ne, nr, int(eff_time), mag))

    def _save_events(self, rup_array, rgetters):
         oq = self.oqparam
        oq = self.oqparam
        # this is very fast compared to saving the ruptures
        E = rup_array['n_occ'].sum()
        events = numpy.zeros(E, rupture.events_dt)
@@ -430,7 +432,7 @@ def starmap_from_gmfs(task_func, oq, dstore, mon):
        slices.append(get_slices(sbe[slc], data, num_assets))
    slices = numpy.concatenate(slices, dtype=slices[0].dtype)
    dstore.swmr_on()
    maxw = slices['weight'].sum()/ (oq.concurrent_tasks or 1) or 1.
    maxw = slices['weight'].sum() / (oq.concurrent_tasks or 1) or 1.
    logging.info('maxw = {:_d}'.format(int(maxw)))
    smap = parallel.Starmap.apply(
        task_func, (slices, oq, ds),
@@ -439,3 +441,31 @@ def starmap_from_gmfs(task_func, oq, dstore, mon):
        weight=operator.itemgetter('weight'),
        h5=dstore.hdf5)
    return smap
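The whitespace fix above touches the task-weight formula, which is worth unpacking: the total slice weight is divided among the concurrent tasks, with two fallbacks guarding the edge cases. An illustrative rendering (the values are made up):

def max_weight(total_weight, concurrent_tasks):
    # `concurrent_tasks or 1` avoids dividing by zero when the parameter
    # is 0 or None; the trailing `or 1.` keeps the result positive when
    # the total weight is zero
    return total_weight / (concurrent_tasks or 1) or 1.

print(max_weight(1000.0, 8))  # 125.0 -> each task gets ~125 units of work
print(max_weight(1000.0, 0))  # 1000.0 -> serial fallback, one big task
print(max_weight(0.0, 8))     # 1.0 -> degenerate case, still positive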


def get_close_mosaic_models(lon, lat, buffer_radius):
    """
    :param lon: longitude
    :param lat: latitude
    :param buffer_radius: radius of the buffer around the point.
        This distance is in the same units as the point's
        coordinates (i.e. degrees), and it defines how far from
        the point the buffer should extend in all directions,
        creating a circular buffer region around the point
    :returns: list of mosaic models intersecting the circle
        centered on the given coordinates having the specified radius
    """
    mosaic_df = readinput.read_mosaic_df(buffer=1)
    hypocenter = Point(lon, lat)
    hypo_buffer = hypocenter.buffer(buffer_radius)
    geoms = numpy.array([hypo_buffer])
    [close_mosaic_models] = geo.utils.geolocate_geometries(geoms, mosaic_df)
    if not close_mosaic_models:
        raise ValueError(
            f'({lon}, {lat}) is farther than {buffer_radius} deg'
            f' from any mosaic model!')
    elif len(close_mosaic_models) > 1:
        logging.info(
            '(%s, %s) is within %s deg of the following mosaic'
            ' models: %s' % (lon, lat, buffer_radius, close_mosaic_models))
    return close_mosaic_models
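An illustrative call of the relocated helper (the coordinates are arbitrary, and the mosaic data read by readinput.read_mosaic_df must be available):

from openquake.commonlib.calc import get_close_mosaic_models

# mosaic models whose geometry intersects a 1-degree circle around
# a point in central Italy
models = get_close_mosaic_models(lon=13.4, lat=42.3, buffer_radius=1.0)
print(models)  # e.g. ['EUR']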
32 changes: 2 additions & 30 deletions openquake/engine/aristotle.py
@@ -24,14 +24,14 @@
import logging
from dataclasses import dataclass
import numpy
from shapely.geometry import Point
from json.decoder import JSONDecodeError
from urllib.error import HTTPError
from openquake.baselib import config, hdf5, sap
from openquake.hazardlib import geo, nrml, sourceconverter
from openquake.hazardlib import nrml, sourceconverter
from openquake.hazardlib.shakemap.parsers import (
    download_rupture_dict, download_station_data_file)
from openquake.commonlib import readinput
from openquake.commonlib.calc import get_close_mosaic_models
from openquake.engine import engine

CDIR = os.path.dirname(__file__) # openquake/engine
@@ -55,34 +55,6 @@ class AristotleParam:
    ignore_shakemap: bool = False


def get_close_mosaic_models(lon, lat, buffer_radius):
    """
    :param lon: longitude
    :param lat: latitude
    :param buffer_radius: radius of the buffer around the point.
        This distance is in the same units as the point's
        coordinates (i.e. degrees), and it defines how far from
        the point the buffer should extend in all directions,
        creating a circular buffer region around the point
    :returns: list of mosaic models intersecting the circle
        centered on the given coordinates having the specified radius
    """
    mosaic_df = readinput.read_mosaic_df(buffer=1)
    hypocenter = Point(lon, lat)
    hypo_buffer = hypocenter.buffer(buffer_radius)
    geoms = numpy.array([hypo_buffer])
    [close_mosaic_models] = geo.utils.geolocate_geometries(geoms, mosaic_df)
    if not close_mosaic_models:
        raise ValueError(
            f'({lon}, {lat}) is farther than {buffer_radius} deg'
            f' from any mosaic model!')
    elif len(close_mosaic_models) > 1:
        logging.info(
            '(%s, %s) is closer than %s deg with respect to the following'
            ' mosaic models: %s' % (lon, lat, buffer_radius, close_mosaic_models))
    return close_mosaic_models


def get_trts_around(mosaic_model, exposure_hdf5):
    """
    :returns: list of TRTs for the given mosaic model
73 changes: 4 additions & 69 deletions openquake/engine/engine.py
@@ -46,7 +46,8 @@ def setproctitle(title):
from openquake.baselib import parallel, general, config, slurm, workerpool as w
from openquake.commonlib.oqvalidation import OqParam
from openquake.commonlib import readinput, logs
from openquake.calculators import base, export
from openquake.calculators import base
from openquake.calculators.base import expose_outputs


USER = getpass.getuser()
@@ -91,72 +92,6 @@ def set_concurrent_tasks_default(calc):
    logging.warning('Using %d %s workers', num_workers, dist)


def expose_outputs(dstore, owner=USER, status='complete'):
    """
    Build a correspondence between the outputs in the datastore and the
    ones in the database.

    :param dstore: datastore
    """
    oq = dstore['oqparam']
    exportable = set(ekey[0] for ekey in export.export)
    calcmode = oq.calculation_mode
    dskeys = set(dstore) & exportable  # exportable datastore keys
    dskeys.add('fullreport')
    if 'avg_gmf' in dskeys:
        dskeys.remove('avg_gmf')  # hide
    rlzs = dstore['full_lt'].rlzs
    if len(rlzs) > 1:
        dskeys.add('realizations')
    hdf5 = dstore.hdf5
    if 'hcurves-stats' in hdf5 or 'hcurves-rlzs' in hdf5:
        if oq.hazard_stats() or oq.individual_rlzs or len(rlzs) == 1:
            dskeys.add('hcurves')
        if oq.uniform_hazard_spectra:
            dskeys.add('uhs')  # export them
        if oq.hazard_maps:
            dskeys.add('hmaps')  # export them
    if len(rlzs) > 1 and not oq.collect_rlzs:
        if 'aggrisk' in dstore:
            dskeys.add('aggrisk-stats')
        if 'aggcurves' in dstore:
            dskeys.add('aggcurves-stats')
        if not oq.individual_rlzs:
            for out in ['avg_losses-rlzs', 'aggrisk', 'aggcurves']:
                if out in dskeys:
                    dskeys.remove(out)
    if 'curves-rlzs' in dstore and len(rlzs) == 1:
        dskeys.add('loss_curves-rlzs')
    if 'curves-stats' in dstore and len(rlzs) > 1:
        dskeys.add('loss_curves-stats')
    if oq.conditional_loss_poes:  # expose loss_maps outputs
        if 'loss_curves-stats' in dstore:
            dskeys.add('loss_maps-stats')
    if 'ruptures' in dskeys:
        if 'scenario' in calcmode or len(dstore['ruptures']) == 0:
            # do not export, as requested by Vitor
            exportable.remove('ruptures')
        else:
            dskeys.add('event_based_mfd')
    if 'hmaps' in dskeys and not oq.hazard_maps:
        dskeys.remove('hmaps')  # do not export the hazard maps
    if logs.dbcmd('get_job', dstore.calc_id) is None:
        # the calculation has not been imported in the db yet
        logs.dbcmd('import_job', dstore.calc_id, oq.calculation_mode,
                   oq.description + ' [parent]', owner, status,
                   oq.hazard_calculation_id, dstore.datadir)
    keysize = []
    for key in sorted(dskeys & exportable):
        try:
            size_mb = dstore.getsize(key) / MB
        except (KeyError, AttributeError):
            size_mb = -1
        if size_mb:
            keysize.append((key, size_mb))
    ds_size = dstore.getsize() / MB
    logs.dbcmd('create_outputs', dstore.calc_id, keysize, ds_size)


class MasterKilled(KeyboardInterrupt):
"Exception raised when a job is killed manually"

@@ -326,7 +261,7 @@ def stop_workers(job_id):
"""
print(w.WorkerMaster(job_id).stop())


def run_jobs(jobctxs, concurrent_jobs=None, nodes=1, sbatch=False, precalc=False):
"""
Run jobs using the specified config file and other options.
@@ -385,7 +320,7 @@ def run_jobs(jobctxs, concurrent_jobs=None, nodes=1, sbatch=False, precalc=False
               'start_time': datetime.utcnow()}
        logs.dbcmd('update_job', job.calc_id, dic)
    try:
        if dist in ('zmq', 'slurm') and w.WorkerMaster(job_id).status() == []:
        if dist in ('zmq', 'slurm') and w.WorkerMaster(job_id).status() == []:
            start_workers(job_id, dist, nodes)

        # run the jobs sequentially or in parallel, with slurm or without
Expand Down
(diffs for the remaining 2 of the 10 changed files are not shown)