Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of median amplitude normalization. #261

Merged
merged 5 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data/examples/unpack.bash
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ cd $(dirname ${BASH_SOURCE[0]})
wd=$PWD

for filename in \
20090407201255351.tgz 20210809074550.tgz 20SPECFEM3D_SGT.tgz SPECFEM3D_SAC.tgz;
20090407201255351.tgz 20210809074550.tgz SPECFEM3D_SGT.tgz SPECFEM3D_SAC.tgz;
do
cd $wd
cd $(dirname $filename)
Expand Down
6 changes: 4 additions & 2 deletions mtuq/graphics/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from matplotlib import pyplot
from obspy.geodetics import gps2dist_azimuth
# location to degree distance with obspy
from obspy.geodetics import locations2degrees


def station_label_writer(ax, station, origin, units='km'):
Expand Down Expand Up @@ -31,7 +33,8 @@ def station_label_writer(ax, station, origin, units='km'):
label = '%d km' % round(distance_in_m/1000.)

elif units=='deg':
label = '%d%s' % (round(m_to_deg(distance_in_m)), u'\N{DEGREE SIGN}')
label = '%d%s' % (round(locations2degrees(origin.latitude, origin.longitude,
station.latitude, station.longitude)), u'\N{DEGREE SIGN}')

pyplot.text(0.2,0.35, label, fontsize=11, transform=ax.transAxes)

Expand Down Expand Up @@ -89,4 +92,3 @@ def _getattr(trace, name, *args):
else:
raise TypeError("Wrong number of arguments")


2 changes: 1 addition & 1 deletion mtuq/graphics/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def parse_data_processing(self):
if not self.process_bw:
pass
if not self.process_sw:
raise Excpetion()
raise Exception()

if self.process_sw.freq_max > 1.:
units = 'Hz'
Expand Down
114 changes: 91 additions & 23 deletions mtuq/graphics/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ def plot_waveforms1(filename,

max_amplitude = _max(data, synthetics)

if normalize == 'median_amplitude':
# Using the updated _median_amplitude function to calculate the median of non-zero maximum amplitudes
max_amplitude_median = _median_amplitude(data, synthetics)
max_amplitudes = np.array([max_amplitude_median if len(data[i]) > 0 and len(synthetics[i]) > 0 else 0.0 for i in range(len(data))])
elif normalize == 'maximum_amplitude':
max_amplitudes = np.array([max_amplitude if len(data[i]) > 0 and len(synthetics[i]) > 0 else 0.0 for i in range(len(data))])
elif normalize == 'station_amplitude' or normalize == 'trace_amplitude':
pass
else:
raise ValueError("Invalid normalization method specified.")

#
# loop over stations
#
Expand Down Expand Up @@ -100,7 +111,7 @@ def plot_waveforms1(filename,
continue

_plot_ZRT(axes[ir], 1, dat, syn, component,
normalize, trace_label_writer, max_amplitude, total_misfit)
normalize, trace_label_writer, max_amplitudes[_i], total_misfit)

ir += 1

Expand Down Expand Up @@ -153,6 +164,23 @@ def plot_waveforms2(filename,
max_amplitude_sw = _max(data_sw, synthetics_sw)


if normalize == 'median_amplitude':
# For body wave data and synthetics
bw_median = _median_amplitude(data_bw, synthetics_bw)
max_amplitudes_bw = np.array([bw_median if len(data_bw[i]) > 0 and len(synthetics_bw[i]) > 0 else 0.0 for i in range(len(data_bw))])

# For surface wave data and synthetics
sw_median = _median_amplitude(data_sw, synthetics_sw)
max_amplitudes_sw = np.array([sw_median if len(data_sw[i]) > 0 and len(synthetics_sw[i]) > 0 else 0.0 for i in range(len(data_sw))])
elif normalize == 'maximum_amplitude':
max_amplitudes_bw = np.array([max_amplitude_bw if len(data_bw[i]) > 0 and len(synthetics_bw[i]) > 0 else 0.0 for i in range(len(data_bw))])
max_amplitudes_sw = np.array([max_amplitude_sw if len(data_sw[i]) > 0 and len(synthetics_sw[i]) > 0 else 0.0 for i in range(len(data_sw))])
elif normalize == 'station_amplitude' or normalize == 'trace_amplitude':
max_amplitudes_bw = np.array([_max(data_bw[i], synthetics_bw[i]) if len(data_bw[i]) > 0 and len(synthetics_bw[i]) > 0 else 0.0 for i in range(len(data_bw))])
max_amplitudes_sw = np.array([_max(data_sw[i], synthetics_sw[i]) if len(data_sw[i]) > 0 and len(synthetics_sw[i]) > 0 else 0.0 for i in range(len(data_sw))])
else:
raise ValueError("Invalid normalization method specified.")

#
# loop over stations
#
Expand Down Expand Up @@ -191,7 +219,7 @@ def plot_waveforms2(filename,
continue

_plot_ZR(axes[ir], 1, dat, syn, component,
normalize, trace_label_writer, max_amplitude_bw, total_misfit_bw)
normalize, trace_label_writer, max_amplitudes_bw[_i], total_misfit_bw)


#
Expand All @@ -216,7 +244,7 @@ def plot_waveforms2(filename,
continue

_plot_ZRT(axes[ir], 3, dat, syn, component,
normalize, trace_label_writer, max_amplitude_sw, total_misfit_sw)
normalize, trace_label_writer, max_amplitudes_sw[_i], total_misfit_sw)


ir += 1
Expand Down Expand Up @@ -373,7 +401,7 @@ def _initialize(nrows=None, ncolumns=None, column_width_ratios=None,

def _plot_ZRT(axes, ic, dat, syn, component,
normalize='maximum_amplitude', trace_label_writer=None,
max_amplitude=1., total_misfit=1.):
normalization_amplitude=1., total_misfit=1.):

# plot traces
if component=='Z':
Expand All @@ -387,17 +415,13 @@ def _plot_ZRT(axes, ic, dat, syn, component,

_plot(axis, dat, syn)

# normalize amplitude
# normalize amplitude -- logic for station_amplitude, median_amplitude, and maximum_amplitude is done at higher level
if normalize=='trace_amplitude':
max_trace = _max(dat, syn)
ylim = [-1.5*max_trace, +1.5*max_trace]
axis.set_ylim(*ylim)
elif normalize=='station_amplitude':
max_stream = _max(stream_dat, stream_syn)
ylim = [-1.5*max_stream, +1.5*max_stream]
axis.set_ylim(*ylim)
elif normalize=='maximum_amplitude':
ylim = [-0.75*max_amplitude, +0.75*max_amplitude]
elif normalize=='station_amplitude' or normalize=='median_amplitude' or normalize=='maximum_amplitude':
ylim = [-1.25*normalization_amplitude, +1.25*normalization_amplitude]
axis.set_ylim(*ylim)

if trace_label_writer is not None:
Expand All @@ -406,7 +430,7 @@ def _plot_ZRT(axes, ic, dat, syn, component,

def _plot_ZR(axes, ic, dat, syn, component,
normalize='maximum_amplitude', trace_label_writer=None,
max_amplitude=1., total_misfit=1.):
normalization_amplitude=1., total_misfit=1.):

# plot traces
if component=='Z':
Expand All @@ -418,20 +442,15 @@ def _plot_ZR(axes, ic, dat, syn, component,

_plot(axis, dat, syn)

# normalize amplitude
# normalize amplitude -- logic for station_amplitude, median_amplitude, and maximum_amplitude is done at higher level
if normalize=='trace_amplitude':
max_trace = _max(dat, syn)
ylim = [-1.5*max_trace, +1.5*max_trace]
axis.set_ylim(*ylim)
elif normalize=='station_amplitude':
max_stream = _max(stream_dat, stream_syn)
ylim = [-1.5*max_stream, +1.5*max_stream]
axis.set_ylim(*ylim)
elif normalize=='maximum_amplitude':
ylim = [-0.75*max_amplitude, +0.75*max_amplitude]
elif normalize=='station_amplitude' or normalize=='median_amplitude' or normalize=='maximum_amplitude':
ylim = [-1.25*normalization_amplitude, +1.25*normalization_amplitude]
axis.set_ylim(*ylim)


if trace_label_writer is not None:
trace_label_writer(axis, dat, syn, total_misfit)

Expand All @@ -450,9 +469,9 @@ def _plot(axis, dat, syn, label=None):
s = syn.data

axis.plot(t, d, 'k', linewidth=1.5,
clip_on=False, zorder=10)
clip_on=True, zorder=10)
axis.plot(t, s[start:stop], 'r', linewidth=1.25,
clip_on=False, zorder=10)
clip_on=True, zorder=10)


def _add_component_labels1(axes, body_wave_labels=True, surface_wave_labels=True):
Expand Down Expand Up @@ -534,6 +553,19 @@ def _isempty(dataset):


def _max(dat, syn):
"""
Computes the maximum value from a set of two input data objects (observed and synthetics).

Parameters:
dat (Trace, Stream, or Dataset): observed data.
syn (Trace, Stream, or Dataset): synthetics.

Returns:
float: The maximum value for normalization purposes.

Raises:
TypeError: If the input objects are not of the same type (Trace, Stream, or Dataset).
"""
if type(dat)==type(syn)==Trace:
return max(
abs(dat.max()),
Expand All @@ -552,6 +584,43 @@ def _max(dat, syn):
else:
raise TypeError

def _median_amplitude(data, synthetics):
"""
Computes the median of the maximum non-zero amplitudes for pairs of data and synthetic traces.

Args:
data: A list of of observed data (can be Trace, Stream, or Dataset objects).
synthetics: A list of synthetic traces corresponding to the observed data.

Returns:
The median of the non-zero maximum amplitudes computed across all pairs.

Raises:
ValueError: If the lengths of data and synthetics lists differ.
"""
# Validate input lengths
# If Trace directly input, make it a list
data = [data] if isinstance(data, Trace) else data
synthetics = [synthetics] if isinstance(synthetics, Trace) else synthetics

# Validate lengths
if len(data) != len(synthetics):
raise ValueError("Data and synthetics lists must have the same length.")

max_amplitudes = []

# Iterate over pairs and handle empty traces - This gets a list of maximum amplitudes for each pair of data and synthetics
for dat, syn in zip(data, synthetics):
if not dat or not syn:
max_amplitudes.append(0)
else:
max_amplitudes.append(_max(dat, syn))

# Convert to NumPy array for efficient filtering
max_amplitudes = np.array(max_amplitudes)

# Compute median of non-zero values or return 0 if none exist
return np.median(max_amplitudes[max_amplitudes > 0]) if np.any(max_amplitudes > 0) else 0.0


def _hide_axes(axes):
Expand Down Expand Up @@ -597,4 +666,3 @@ def _get_tag(tags, pattern):
else:
return None


2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def run_tests(self):
# (consider using a conda based installation instead)
install_requires=[
"numpy",
"scipy",
"scipy<1.13.0",
"pandas",
"xarray",
"netCDF4",
Expand Down
Loading