diff --git a/tests/test_record.py b/tests/test_record.py index 3459897b..313eeb23 100644 --- a/tests/test_record.py +++ b/tests/test_record.py @@ -20,6 +20,25 @@ class TestRecord(unittest.TestCase): """ + wrsamp_params = [ + "record_name", + "fs", + "units", + "sig_name", + "p_signal", + "d_signal", + "e_p_signal", + "e_d_signal", + "samps_per_frame", + "fmt", + "adc_gain", + "baseline", + "comments", + "base_time", + "base_date", + "base_datetime", + ] + # ----------------------- 1. Basic Tests -----------------------# def test_1a(self): @@ -286,6 +305,172 @@ def test_read_write_flac_multifrequency(self): ) assert record == record_write + def test_unique_samps_per_frame_e_p_signal(self): + """ + Test writing an e_p_signal with wfdb.io.wrsamp where the signals have different samples per frame. All other + parameters which overlap between a Record object and wfdb.io.wrsamp are also checked. + """ + # Read in a record with different samples per frame + record = wfdb.rdrecord( + "sample-data/mixedsignals", + smooth_frames=False, + ) + + # Write the signals + wfdb.io.wrsamp( + "mixedsignals", + fs=record.fs, + units=record.units, + sig_name=record.sig_name, + base_date=record.base_date, + base_time=record.base_time, + comments=record.comments, + p_signal=record.p_signal, + d_signal=record.d_signal, + e_p_signal=record.e_p_signal, + e_d_signal=record.e_d_signal, + samps_per_frame=record.samps_per_frame, + baseline=record.baseline, + adc_gain=record.adc_gain, + fmt=record.fmt, + write_dir=self.temp_path, + ) + + # Check that the written record matches the original + # Read in the original and written records + record = wfdb.rdrecord("sample-data/mixedsignals", smooth_frames=False) + record_write = wfdb.rdrecord( + os.path.join(self.temp_path, "mixedsignals"), + smooth_frames=False, + ) + + # Check that the signals match + for n, name in enumerate(record.sig_name): + np.testing.assert_array_equal( + record.e_p_signal[n], + record_write.e_p_signal[n], + f"Mismatch in {name}", + ) + + # Filter out the signal + record_filtered = { + k: getattr(record, k) + for k in self.wrsamp_params + if not ( + isinstance(getattr(record, k), np.ndarray) + or ( + isinstance(getattr(record, k), list) + and all( + isinstance(item, np.ndarray) + for item in getattr(record, k) + ) + ) + ) + } + + record_write_filtered = { + k: getattr(record_write, k) + for k in self.wrsamp_params + if not ( + isinstance(getattr(record_write, k), np.ndarray) + or ( + isinstance(getattr(record_write, k), list) + and all( + isinstance(item, np.ndarray) + for item in getattr(record_write, k) + ) + ) + ) + } + + # Check that the arguments beyond the signals also match + assert record_filtered == record_write_filtered + + def test_unique_samps_per_frame_e_d_signal(self): + """ + Test writing an e_d_signal with wfdb.io.wrsamp where the signals have different samples per frame. All other + parameters which overlap between a Record object and wfdb.io.wrsamp are also checked. + """ + # Read in a record with different samples per frame + record = wfdb.rdrecord( + "sample-data/mixedsignals", + physical=False, + smooth_frames=False, + ) + + # Write the signals + wfdb.io.wrsamp( + "mixedsignals", + fs=record.fs, + units=record.units, + sig_name=record.sig_name, + base_date=record.base_date, + base_time=record.base_time, + comments=record.comments, + p_signal=record.p_signal, + d_signal=record.d_signal, + e_p_signal=record.e_p_signal, + e_d_signal=record.e_d_signal, + samps_per_frame=record.samps_per_frame, + baseline=record.baseline, + adc_gain=record.adc_gain, + fmt=record.fmt, + write_dir=self.temp_path, + ) + + # Check that the written record matches the original + # Read in the original and written records + record = wfdb.rdrecord( + "sample-data/mixedsignals", physical=False, smooth_frames=False + ) + record_write = wfdb.rdrecord( + os.path.join(self.temp_path, "mixedsignals"), + physical=False, + smooth_frames=False, + ) + + # Check that the signals match + for n, name in enumerate(record.sig_name): + np.testing.assert_array_equal( + record.e_d_signal[n], + record_write.e_d_signal[n], + f"Mismatch in {name}", + ) + + # Filter out the signal + record_filtered = { + k: getattr(record, k) + for k in self.wrsamp_params + if not ( + isinstance(getattr(record, k), np.ndarray) + or ( + isinstance(getattr(record, k), list) + and all( + isinstance(item, np.ndarray) + for item in getattr(record, k) + ) + ) + ) + } + + record_write_filtered = { + k: getattr(record_write, k) + for k in self.wrsamp_params + if not ( + isinstance(getattr(record_write, k), np.ndarray) + or ( + isinstance(getattr(record_write, k), list) + and all( + isinstance(item, np.ndarray) + for item in getattr(record_write, k) + ) + ) + ) + } + + # Check that the arguments beyond the signals also match + assert record_filtered == record_write_filtered + def test_read_write_flac_many_channels(self): """ Check we can read and write to format 516 with more than 8 channels. diff --git a/wfdb/io/_signal.py b/wfdb/io/_signal.py index e3df065e..4687aece 100644 --- a/wfdb/io/_signal.py +++ b/wfdb/io/_signal.py @@ -433,7 +433,7 @@ def set_d_features(self, do_adc=False, single_fmt=True, expanded=False): self.check_field("baseline", "all") # All required fields are present and valid. Perform ADC - self.d_signal = self.adc(expanded) + self.e_d_signal = self.adc(expanded) # Use e_d_signal to set fields self.check_field("e_d_signal", "all") diff --git a/wfdb/io/record.py b/wfdb/io/record.py index 4b900b17..1a8855ed 100644 --- a/wfdb/io/record.py +++ b/wfdb/io/record.py @@ -2822,6 +2822,9 @@ def wrsamp( sig_name, p_signal=None, d_signal=None, + e_p_signal=None, + e_d_signal=None, + samps_per_frame=None, fmt=None, adc_gain=None, baseline=None, @@ -2860,6 +2863,14 @@ def wrsamp( file(s). The dtype must be an integer type. Either p_signal or d_signal must be set, but not both. In addition, if d_signal is set, fmt, gain and baseline must also all be set. + e_p_signal : ndarray, optional + The expanded physical conversion of the signal. Either a 2d numpy + array or a list of 1d numpy arrays. + e_d_signal : ndarray, optional + The expanded digital conversion of the signal. Either a 2d numpy + array or a list of 1d numpy arrays. + samps_per_frame : int or list of ints, optional + The total number of samples per frame. fmt : list, optional A list of strings giving the WFDB format of each file used to store each channel. Accepted formats are: '80','212','16','24', and '32'. There are @@ -2911,59 +2922,105 @@ def wrsamp( if "." in record_name: raise Exception("Record name must not contain '.'") # Check input field combinations - if p_signal is not None and d_signal is not None: + signal_list = [p_signal, d_signal, e_p_signal, e_d_signal] + signals_set = sum(1 for var in signal_list if var is not None) + if signals_set != 1: raise Exception( - "Must only give one of the inputs: p_signal or d_signal" + "Must provide one and only one input signal: p_signal, d_signal, e_p_signal, or e_d_signal" ) - if d_signal is not None: + if d_signal is not None or e_d_signal is not None: if fmt is None or adc_gain is None or baseline is None: raise Exception( - "When using d_signal, must also specify 'fmt', 'gain', and 'baseline' fields." + "When using d_signal or e_d_signal, must also specify 'fmt', 'gain', and 'baseline' fields" ) - # Depending on whether d_signal or p_signal was used, set other - # required features. - if p_signal is not None: - # Create the Record object - record = Record( - record_name=record_name, - p_signal=p_signal, - fs=fs, - fmt=fmt, - units=units, - sig_name=sig_name, - adc_gain=adc_gain, - baseline=baseline, - comments=comments, - base_time=base_time, - base_date=base_date, - base_datetime=base_datetime, + if ( + e_p_signal is not None or e_d_signal is not None + ) and samps_per_frame is None: + raise Exception( + "When passing e_p_signal or e_d_signal, you also need to specify samples per frame for each channel" + ) + + # If samps_per_frame is provided, check that it aligns as expected with the channels in the signal + if samps_per_frame: + # Get the number of elements being passed in samps_per_frame + samps_per_frame_length = ( + len(samps_per_frame) if isinstance(samps_per_frame, list) else 1 ) + # Get properties of the signal being passed + first_valid_signal = next( + signal for signal in signal_list if signal is not None + ) + if isinstance(first_valid_signal, np.ndarray): + num_sig_channels = first_valid_signal.shape[1] + channel_samples = [ + first_valid_signal.shape[0] + ] * first_valid_signal.shape[1] + elif isinstance(first_valid_signal, list): + num_sig_channels = len(first_valid_signal) + channel_samples = [len(channel) for channel in first_valid_signal] + else: + raise TypeError( + "Unsupported signal format. Must be ndarray or list of lists." + ) + # Check that the number of channels matches the number of samps_per_frame entries + if num_sig_channels != samps_per_frame_length: + raise Exception( + "When passing samps_per_frame, it must have the same number of entries as the signal has channels" + ) + # Check that the number of frames is the same across all channels + frames = [a / b for a, b in zip(channel_samples, samps_per_frame)] + if len(set(frames)) > 1: + raise Exception( + "The number of samples in a channel divided by the corresponding samples_per_frame entry must be uniform" + ) + + # Create the Record object + record = Record( + record_name=record_name, + p_signal=p_signal, + d_signal=d_signal, + e_p_signal=e_p_signal, + e_d_signal=e_d_signal, + samps_per_frame=samps_per_frame, + fs=fs, + fmt=fmt, + units=units, + sig_name=sig_name, + adc_gain=adc_gain, + baseline=baseline, + comments=comments, + base_time=base_time, + base_date=base_date, + base_datetime=base_datetime, + ) + + # Depending on which signal was used, set other required fields. + if p_signal is not None: # Compute optimal fields to store the digital signal, carry out adc, # and set the fields. record.set_d_features(do_adc=1) - else: - # Create the Record object - record = Record( - record_name=record_name, - d_signal=d_signal, - fs=fs, - fmt=fmt, - units=units, - sig_name=sig_name, - adc_gain=adc_gain, - baseline=baseline, - comments=comments, - base_time=base_time, - base_date=base_date, - base_datetime=base_datetime, - ) + elif d_signal is not None: # Use d_signal to set the fields directly record.set_d_features() + elif e_p_signal is not None: + # Compute optimal fields to store the digital signal, carry out adc, + # and set the fields. + record.set_d_features(do_adc=1, expanded=True) + elif e_d_signal is not None: + # Use e_d_signal to set the fields directly + record.set_d_features(expanded=True) # Set default values of any missing field dependencies record.set_defaults() + + # Determine whether the signal is expanded + if (e_d_signal or e_p_signal) is not None: + expanded = True + else: + expanded = False + # Write the record files - header and associated dat - record.wrsamp(write_dir=write_dir) + record.wrsamp(write_dir=write_dir, expanded=expanded) def dl_database(