Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add capability to write signal with unique samps_per_frame to wfdb.io.wrsamp #510

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
185 changes: 185 additions & 0 deletions tests/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,25 @@ class TestRecord(unittest.TestCase):

"""

wrsamp_params = [
"record_name",
"fs",
"units",
"sig_name",
"p_signal",
"d_signal",
"e_p_signal",
"e_d_signal",
"samps_per_frame",
"fmt",
"adc_gain",
"baseline",
"comments",
"base_time",
"base_date",
"base_datetime",
]

# ----------------------- 1. Basic Tests -----------------------#

def test_1a(self):
Expand Down Expand Up @@ -286,6 +305,172 @@ def test_read_write_flac_multifrequency(self):
)
assert record == record_write

def test_unique_samps_per_frame_e_p_signal(self):
"""
Test writing an e_p_signal with wfdb.io.wrsamp where the signals have different samples per frame. All other
parameters which overlap between a Record object and wfdb.io.wrsamp are also checked.
"""
# Read in a record with different samples per frame
record = wfdb.rdrecord(
"sample-data/mixedsignals",
smooth_frames=False,
)

# Write the signals
wfdb.io.wrsamp(
"mixedsignals",
fs=record.fs,
units=record.units,
sig_name=record.sig_name,
base_date=record.base_date,
base_time=record.base_time,
comments=record.comments,
p_signal=record.p_signal,
d_signal=record.d_signal,
e_p_signal=record.e_p_signal,
e_d_signal=record.e_d_signal,
samps_per_frame=record.samps_per_frame,
baseline=record.baseline,
adc_gain=record.adc_gain,
fmt=record.fmt,
write_dir=self.temp_path,
)

# Check that the written record matches the original
# Read in the original and written records
record = wfdb.rdrecord("sample-data/mixedsignals", smooth_frames=False)
record_write = wfdb.rdrecord(
os.path.join(self.temp_path, "mixedsignals"),
smooth_frames=False,
)

# Check that the signals match
for n, name in enumerate(record.sig_name):
np.testing.assert_array_equal(
record.e_p_signal[n],
record_write.e_p_signal[n],
f"Mismatch in {name}",
)

# Filter out the signal
record_filtered = {
k: getattr(record, k)
for k in self.wrsamp_params
if not (
isinstance(getattr(record, k), np.ndarray)
or (
isinstance(getattr(record, k), list)
and all(
isinstance(item, np.ndarray)
for item in getattr(record, k)
)
)
)
}

record_write_filtered = {
k: getattr(record_write, k)
for k in self.wrsamp_params
if not (
isinstance(getattr(record_write, k), np.ndarray)
or (
isinstance(getattr(record_write, k), list)
and all(
isinstance(item, np.ndarray)
for item in getattr(record_write, k)
)
)
)
}

# Check that the arguments beyond the signals also match
assert record_filtered == record_write_filtered

def test_unique_samps_per_frame_e_d_signal(self):
"""
Test writing an e_d_signal with wfdb.io.wrsamp where the signals have different samples per frame. All other
parameters which overlap between a Record object and wfdb.io.wrsamp are also checked.
"""
# Read in a record with different samples per frame
record = wfdb.rdrecord(
"sample-data/mixedsignals",
physical=False,
smooth_frames=False,
)

# Write the signals
wfdb.io.wrsamp(
"mixedsignals",
fs=record.fs,
units=record.units,
sig_name=record.sig_name,
base_date=record.base_date,
base_time=record.base_time,
comments=record.comments,
p_signal=record.p_signal,
d_signal=record.d_signal,
e_p_signal=record.e_p_signal,
e_d_signal=record.e_d_signal,
samps_per_frame=record.samps_per_frame,
baseline=record.baseline,
adc_gain=record.adc_gain,
fmt=record.fmt,
write_dir=self.temp_path,
)

# Check that the written record matches the original
# Read in the original and written records
record = wfdb.rdrecord(
"sample-data/mixedsignals", physical=False, smooth_frames=False
)
record_write = wfdb.rdrecord(
os.path.join(self.temp_path, "mixedsignals"),
physical=False,
smooth_frames=False,
)

# Check that the signals match
for n, name in enumerate(record.sig_name):
np.testing.assert_array_equal(
record.e_d_signal[n],
record_write.e_d_signal[n],
f"Mismatch in {name}",
)

# Filter out the signal
record_filtered = {
k: getattr(record, k)
for k in self.wrsamp_params
if not (
isinstance(getattr(record, k), np.ndarray)
or (
isinstance(getattr(record, k), list)
and all(
isinstance(item, np.ndarray)
for item in getattr(record, k)
)
)
)
}

record_write_filtered = {
k: getattr(record_write, k)
for k in self.wrsamp_params
if not (
isinstance(getattr(record_write, k), np.ndarray)
or (
isinstance(getattr(record_write, k), list)
and all(
isinstance(item, np.ndarray)
for item in getattr(record_write, k)
)
)
)
}

# Check that the arguments beyond the signals also match
assert record_filtered == record_write_filtered

def test_read_write_flac_many_channels(self):
"""
Check we can read and write to format 516 with more than 8 channels.
Expand Down
2 changes: 1 addition & 1 deletion wfdb/io/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def set_d_features(self, do_adc=False, single_fmt=True, expanded=False):
self.check_field("baseline", "all")

# All required fields are present and valid. Perform ADC
self.d_signal = self.adc(expanded)
self.e_d_signal = self.adc(expanded)

# Use e_d_signal to set fields
self.check_field("e_d_signal", "all")
Expand Down
133 changes: 95 additions & 38 deletions wfdb/io/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -2822,6 +2822,9 @@ def wrsamp(
sig_name,
p_signal=None,
d_signal=None,
e_p_signal=None,
e_d_signal=None,
samps_per_frame=None,
fmt=None,
adc_gain=None,
baseline=None,
Expand Down Expand Up @@ -2860,6 +2863,14 @@ def wrsamp(
file(s). The dtype must be an integer type. Either p_signal or d_signal
must be set, but not both. In addition, if d_signal is set, fmt, gain
and baseline must also all be set.
e_p_signal : ndarray, optional
The expanded physical conversion of the signal. Either a 2d numpy
array or a list of 1d numpy arrays.
e_d_signal : ndarray, optional
The expanded digital conversion of the signal. Either a 2d numpy
array or a list of 1d numpy arrays.
samps_per_frame : int or list of ints, optional
The total number of samples per frame.
fmt : list, optional
A list of strings giving the WFDB format of each file used to store each
channel. Accepted formats are: '80','212','16','24', and '32'. There are
Expand Down Expand Up @@ -2911,59 +2922,105 @@ def wrsamp(
if "." in record_name:
raise Exception("Record name must not contain '.'")
# Check input field combinations
if p_signal is not None and d_signal is not None:
signal_list = [p_signal, d_signal, e_p_signal, e_d_signal]
signals_set = sum(1 for var in signal_list if var is not None)
if signals_set != 1:
Copy link
Member

@tompollard tompollard Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intentionally changing behavior for calls where p_signal=None and d_signal=None? In the past, this would evaluate to False when e_p_signal=None and e_d_signal=None, but it now evaluates to True.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tompollard , if all signals are passed as None a failure will occur further down the line. @bemoody , do we want to allow all signals to be None when calling wfdb.io.wrsamp? If we don't want to allow that then I will leave the code as I have it which requires that one of the signals is set.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If passing None for both signals (p_signal and d_signal) previously resulted in an error, then the new behavior is an improvement. Unless people were intentioanlly calling the function with both signals set to None previously, this change isn't a problem.

raise Exception(
"Must only give one of the inputs: p_signal or d_signal"
"Must provide one and only one input signal: p_signal, d_signal, e_p_signal, or e_d_signal"
)
if d_signal is not None:
if d_signal is not None or e_d_signal is not None:
if fmt is None or adc_gain is None or baseline is None:
raise Exception(
"When using d_signal, must also specify 'fmt', 'gain', and 'baseline' fields."
"When using d_signal or e_d_signal, must also specify 'fmt', 'gain', and 'baseline' fields"
)
# Depending on whether d_signal or p_signal was used, set other
# required features.
if p_signal is not None:
# Create the Record object
record = Record(
record_name=record_name,
p_signal=p_signal,
fs=fs,
fmt=fmt,
units=units,
sig_name=sig_name,
adc_gain=adc_gain,
baseline=baseline,
comments=comments,
base_time=base_time,
base_date=base_date,
base_datetime=base_datetime,
if (
e_p_signal is not None or e_d_signal is not None
) and samps_per_frame is None:
raise Exception(
"When passing e_p_signal or e_d_signal, you also need to specify samples per frame for each channel"
)

# If samps_per_frame is provided, check that it aligns as expected with the channels in the signal
if samps_per_frame:
# Get the number of elements being passed in samps_per_frame
samps_per_frame_length = (
len(samps_per_frame) if isinstance(samps_per_frame, list) else 1
)
# Get properties of the signal being passed
first_valid_signal = next(
signal for signal in signal_list if signal is not None
)
if isinstance(first_valid_signal, np.ndarray):
num_sig_channels = first_valid_signal.shape[1]
channel_samples = [
first_valid_signal.shape[0]
] * first_valid_signal.shape[1]
elif isinstance(first_valid_signal, list):
num_sig_channels = len(first_valid_signal)
channel_samples = [len(channel) for channel in first_valid_signal]
else:
raise TypeError(
"Unsupported signal format. Must be ndarray or list of lists."
)
# Check that the number of channels matches the number of samps_per_frame entries
if num_sig_channels != samps_per_frame_length:
raise Exception(
"When passing samps_per_frame, it must have the same number of entries as the signal has channels"
)
# Check that the number of frames is the same across all channels
frames = [a / b for a, b in zip(channel_samples, samps_per_frame)]
if len(set(frames)) > 1:
raise Exception(
"The number of samples in a channel divided by the corresponding samples_per_frame entry must be uniform"
)

# Create the Record object
record = Record(
record_name=record_name,
p_signal=p_signal,
d_signal=d_signal,
e_p_signal=e_p_signal,
e_d_signal=e_d_signal,
samps_per_frame=samps_per_frame,
fs=fs,
fmt=fmt,
units=units,
sig_name=sig_name,
adc_gain=adc_gain,
baseline=baseline,
comments=comments,
base_time=base_time,
base_date=base_date,
base_datetime=base_datetime,
)

# Depending on which signal was used, set other required fields.
if p_signal is not None:
# Compute optimal fields to store the digital signal, carry out adc,
# and set the fields.
record.set_d_features(do_adc=1)
else:
# Create the Record object
record = Record(
record_name=record_name,
d_signal=d_signal,
fs=fs,
fmt=fmt,
units=units,
sig_name=sig_name,
adc_gain=adc_gain,
baseline=baseline,
comments=comments,
base_time=base_time,
base_date=base_date,
base_datetime=base_datetime,
)
elif d_signal is not None:
# Use d_signal to set the fields directly
record.set_d_features()
elif e_p_signal is not None:
# Compute optimal fields to store the digital signal, carry out adc,
# and set the fields.
record.set_d_features(do_adc=1, expanded=True)
elif e_d_signal is not None:
# Use e_d_signal to set the fields directly
record.set_d_features(expanded=True)

# Set default values of any missing field dependencies
record.set_defaults()

# Determine whether the signal is expanded
if (e_d_signal or e_p_signal) is not None:
expanded = True
else:
expanded = False

# Write the record files - header and associated dat
record.wrsamp(write_dir=write_dir)
record.wrsamp(write_dir=write_dir, expanded=expanded)


def dl_database(
Expand Down
Loading