tested version of pytest suite
JaerongA committed Mar 21, 2024
commit 9c033c4 · 1 parent 3ae2cbc
Showing 3 changed files with 171 additions and 28 deletions.
65 changes: 38 additions & 27 deletions tests/conftest.py
@@ -24,24 +24,31 @@ def dj_config():
             "database.user": os.environ.get("DJ_USER") or dj.config["database.user"],
         }
     )
+    os.environ["DATABASE_PREFIX"] = "test_"
     return
 
 
 @pytest.fixture(autouse=True, scope="session")
 def pipeline():
-    import tutorial_pipeline as pipeline
+    from . import tutorial_pipeline as pipeline
 
     yield {
         "lab": pipeline.lab,
         "subject": pipeline.subject,
         "session": pipeline.session,
         "probe": pipeline.probe,
         "ephys": pipeline.ephys,
+        "ephys_report": pipeline.ephys_report,
         "get_ephys_root_data_dir": pipeline.get_ephys_root_data_dir,
     }
 
     if _tear_down:
-        pipeline.subject.Subject.delete()
+        pipeline.ephys_report.schema.drop()
+        pipeline.ephys.schema.drop()
+        pipeline.probe.schema.drop()
+        pipeline.session.schema.drop()
+        pipeline.subject.schema.drop()
+        pipeline.lab.schema.drop()
 
 
 @pytest.fixture(scope="session")
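
Note on the teardown branch above: it reads a module-level _tear_down flag whose definition sits outside this hunk. A minimal sketch of how such a flag could be wired up near the top of tests/conftest.py — the exact definition is an assumption (the real file may simply hard-code _tear_down = True):

import os

# Hypothetical: let CI keep the test schemas around for debugging.
_tear_down = os.environ.get("PYTEST_TEAR_DOWN", "true").lower() in ("true", "1")
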
@@ -53,37 +60,46 @@ def insert_upstreams(pipeline):
     ephys = pipeline["ephys"]
 
     subject.Subject.insert1(
-        dict(subject="subject5", subject_birth_date="2023-01-01", sex="U")
+        dict(subject="subject5", subject_birth_date="2023-01-01", sex="U"),
+        skip_duplicates=True,
     )
 
     session_key = dict(subject="subject5", session_datetime="2023-01-01 00:00:00")
     session.Session.insert1(session_key, skip_duplicates=True)
     session_dir = "raw/subject5/session1"
 
-    session.SessionDirectory.insert1(dict(**session_key, session_dir=session_dir))
-    probe.Probe.insert1(dict(probe="714000838", probe_type="neuropixels 1.0 - 3B"))
+    session.SessionDirectory.insert1(
+        dict(**session_key, session_dir=session_dir), skip_duplicates=True
+    )
+    probe.Probe.insert1(
+        dict(probe="714000838", probe_type="neuropixels 1.0 - 3B"), skip_duplicates=True
+    )
     ephys.ProbeInsertion.insert1(
         dict(
-            session_key,
+            **session_key,
             insertion_number=1,
             probe="714000838",
-        )
+        ),
+        skip_duplicates=True,
     )
-    yield
-
-    if _tear_down:
-        subject.Subject.delete()
-        probe.Probe.delete()
+    return
 
 
 @pytest.fixture(scope="session")
-def populate_ephys_recording(pipeline, insert_upstream):
+def populate_ephys_recording(pipeline, insert_upstreams):
     ephys = pipeline["ephys"]
     ephys.EphysRecording.populate()
 
-    yield
+    return
 
-    if _tear_down:
-        ephys.EphysRecording.delete()
+
+@pytest.fixture(scope="session")
+def populate_lfp(pipeline, insert_upstreams):
+    ephys = pipeline["ephys"]
+    ephys.LFP.populate()
+
+    return
 
 
 @pytest.fixture(scope="session")
@@ -129,25 +145,20 @@ def insert_clustering_task(pipeline, populate_ephys_recording):
             paramset_idx=0,
             task_mode="load",  # load or trigger
             clustering_output_dir="processed/subject5/session1/probe_1/kilosort2-5_1",
-        )
+        ),
+        skip_duplicates=True,
     )
 
-    yield
-
-    if _tear_down:
-        ephys.ClusteringParamSet.delete()
+    return
 
 
 @pytest.fixture(scope="session")
-def processing(pipeline, populate_ephys_recording):
+def processing(pipeline, insert_clustering_task):
 
     ephys = pipeline["ephys"]
     ephys.Clustering.populate()
     ephys.CuratedClustering.populate()
-    ephys.LFP.populate()
     ephys.WaveformSet.populate()
     ephys.QualityMetrics.populate()
 
-    yield
-
-    if _tear_down:
-        ephys.CuratedClustering.delete()
-        ephys.LFP.delete()
+    return
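
The fixtures above form a session-scoped dependency chain — pipeline, then insert_upstreams, then populate_ephys_recording / populate_lfp, then insert_clustering_task, then processing — so each stage runs at most once per pytest session and later tests simply request the stage they need. As a sketch of how a downstream test would opt into the full chain (hypothetical test, not part of this commit):

def test_unit_count(pipeline, processing):
    # Requesting the processing fixture pulls in the whole upstream chain
    # (inserts, EphysRecording, ClusteringTask, Clustering, etc.) exactly once.
    ephys = pipeline["ephys"]
    assert len(ephys.CuratedClustering.Unit()) > 0
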
132 changes: 132 additions & 0 deletions tests/test_pipeline.py
@@ -0,0 +1,132 @@
import numpy as np
import pandas as pd
import datetime
from uuid import UUID


def test_generate_pipeline(pipeline):
    subject = pipeline["subject"]
    session = pipeline["session"]
    ephys = pipeline["ephys"]
    probe = pipeline["probe"]

    # test elements connection from lab, subject to Session
    assert subject.Subject.full_table_name in session.Session.parents()

    # test elements connection from Session to probe, ephys, ephys_report
    assert session.Session.full_table_name in ephys.ProbeInsertion.parents()
    assert probe.Probe.full_table_name in ephys.ProbeInsertion.parents()
    assert "spike_times" in (ephys.CuratedClustering.Unit.heading.secondary_attributes)


def test_insert_upstreams(pipeline, insert_upstreams):
    """Check number of subjects inserted into the `subject.Subject` table"""
    subject = pipeline["subject"]
    session = pipeline["session"]
    probe = pipeline["probe"]
    ephys = pipeline["ephys"]

    assert len(subject.Subject()) == 1
    assert len(session.Session()) == 1
    assert len(probe.Probe()) == 1
    assert len(ephys.ProbeInsertion()) == 1


def test_populate_ephys_recording(pipeline, populate_ephys_recording):
    ephys = pipeline["ephys"]

    assert ephys.EphysRecording.fetch1() == {
        "subject": "subject5",
        "session_datetime": datetime.datetime(2023, 1, 1, 0, 0),
        "insertion_number": 1,
        "electrode_config_hash": UUID("8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee"),
        "acq_software": "SpikeGLX",
        "sampling_rate": 30000.0,
        "recording_datetime": datetime.datetime(2018, 7, 3, 20, 32, 28),
        "recording_duration": 338.666,
    }
    assert (
        ephys.EphysRecording.EphysFile.fetch1("file_path")
        == "raw/subject5/session1/probe_1/npx_g0_t0.imec.ap.meta"
    )


def test_populate_lfp(pipeline, populate_lfp):
    ephys = pipeline["ephys"]

    assert np.mean(ephys.LFP.fetch1("lfp_mean")) == -716.0220556825378
    assert len((ephys.LFP.Electrode).fetch("electrode")) == 43


def test_insert_clustering_task(pipeline, insert_clustering_task):
    ephys = pipeline["ephys"]

    assert ephys.ClusteringParamSet.fetch1("param_set_hash") == UUID(
        "de78cee1-526f-319e-b6d5-8a2ba04963d8"
    )

    assert ephys.ClusteringTask.fetch1() == {
        "subject": "subject5",
        "session_datetime": datetime.datetime(2023, 1, 1, 0, 0),
        "insertion_number": 1,
        "paramset_idx": 0,
        "clustering_output_dir": "processed/subject5/session1/probe_1/kilosort2-5_1",
        "task_mode": "load",
    }


def test_processing(pipeline, processing):

    ephys = pipeline["ephys"]

    # test ephys.CuratedClustering
    assert len(ephys.CuratedClustering.Unit & 'cluster_quality_label = "good"') == 176
    assert np.sum(ephys.CuratedClustering.Unit.fetch("spike_count")) == 328167
    # test ephys.WaveformSet
    waveforms = np.vstack(
        (ephys.WaveformSet.PeakWaveform).fetch("peak_electrode_waveform")
    )
    assert waveforms.shape == (227, 82)

    # test ephys.QualityMetrics
    cluster_df = (ephys.QualityMetrics.Cluster).fetch(format="frame", order_by="unit")
    waveform_df = (ephys.QualityMetrics.Waveform).fetch(format="frame", order_by="unit")
    test_df = pd.concat([cluster_df, waveform_df], axis=1).reset_index()
    test_value = test_df.select_dtypes(include=[np.number]).mean().values

    assert np.allclose(
        test_value,
        np.array(
            [
                1.00000000e00,
                0.00000000e00,
                1.13000000e02,
                4.26880089e00,
                1.24162431e00,
                7.17929515e-01,
                4.41633793e-01,
                3.08736082e-01,
                1.24039274e15,
                1.66763828e-02,
                4.33231948e00,
                7.12304747e-01,
                1.48995215e-02,
                7.73432472e-02,
                5.06451613e00,
                7.79528634e00,
                6.30182452e-01,
                1.19562726e02,
                7.90175419e-01,
                np.nan,
                8.78436780e-01,
                1.08028193e-01,
                -5.19418717e-02,
                2.36035242e02,
                7.48443665e-02,
                2.77550214e-02,
            ]
        ),
        rtol=1e-03,
        atol=1e-03,
        equal_nan=True,
    )
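
The long expected vector in test_processing is the column-wise mean of the numeric quality-metric columns, compared with np.allclose at 1e-03 tolerance and equal_nan=True so NaN entries compare equal. If the reference dataset or the element-array-ephys metrics ever change, the vector can be regenerated with the same recipe the test uses; a sketch of a hypothetical helper (not part of this commit):

import numpy as np
import pandas as pd


def quality_metric_means(ephys):
    """Recompute the per-column means that test_processing asserts against."""
    cluster_df = ephys.QualityMetrics.Cluster.fetch(format="frame", order_by="unit")
    waveform_df = ephys.QualityMetrics.Waveform.fetch(format="frame", order_by="unit")
    df = pd.concat([cluster_df, waveform_df], axis=1).reset_index()
    return df.select_dtypes(include=[np.number]).mean().values
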
2 changes: 1 addition & 1 deletion tests/tutorial_pipeline.py
@@ -3,7 +3,7 @@
 import datajoint as dj
 from element_animal import subject
 from element_animal.subject import Subject
-from element_array_ephys import probe, ephys_no_curation as ephys
+from element_array_ephys import probe, ephys_no_curation as ephys, ephys_report
 from element_lab import lab
 from element_lab.lab import Lab, Location, Project, Protocol, Source, User
 from element_lab.lab import Device as Equipment
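
For context, importing ephys_report here is what lets conftest.py expose pipeline.ephys_report and drop its schema at teardown. The DATABASE_PREFIX environment variable set in conftest.py is presumably what this module uses to build throwaway test_-prefixed schema names; a minimal sketch of that pattern (the variable handling and activation calls below are assumptions, not copied from the file):

import os

db_prefix = os.environ.get("DATABASE_PREFIX", "")

# Each Element schema would then be activated under the test prefix, e.g.
# lab.activate(db_prefix + "lab")
# ephys.activate(db_prefix + "ephys", db_prefix + "probe", linking_module=__name__)
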
