Skip to content

Commit

Permalink
refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Justin Kiggins committed Nov 22, 2017
1 parent 80d6799 commit a8e3776
Showing 1 changed file with 21 additions and 23 deletions.
44 changes: 21 additions & 23 deletions tests/test_calcium.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from neuroglia.calcium import MedianFilterDetrend, SavGolFilterDetrend
from neuroglia.calcium import OASISInferer
from neuroglia.calcium import MedianFilterDetrender, SavGolFilterDetrender
from neuroglia.calcium import CalciumDeconvolver
from oasis.functions import gen_data
from sklearn.base import clone

Expand All @@ -10,13 +10,6 @@
import numpy.testing as npt
import xarray.testing as xrt

# Test for proper parameter structure
def test_params():
fn_list = [MedianFilterDetrend(), SavGolFilterDetrend(), OASISInferer()]
for fn in fn_list:
new_object_params = fn.get_params(deep=False)
for name, param in new_object_params.items():
new_object_params[name] = clone(param, safe=False)

# Test functions perform as expected
true_b = 2
Expand All @@ -26,28 +19,33 @@ def test_params():
LBL = ['a', 'b', 'c']
sin_scale = 5

data = y + sin_scale*np.sin(.05*TIME)[:,None]
DFF = pd.DataFrame(data, TIME, LBL)
# data = y
DFF = pd.DataFrame(y, TIME, LBL)
DFF_WITH_DRIFT = DFF.apply(lambda y: y + sin_scale*np.sin(.05*TIME),axis=0)

assert np.all(np.mean(DFF) > 2)

def test_MedianFilterDetrend():
tmp = MedianFilterDetrend().fit_transform(DFF)
def test_MedianFilterDetrender():
detrender = MedianFilterDetrender()
tmp = detrender.fit_transform(DFF_WITH_DRIFT)
assert np.all(np.isclose(np.mean(tmp), 0, atol=.1))
clone(detrender)

def test_SavGolFilterDetrend():
tmp = SavGolFilterDetrend().fit_transform(DFF)
def test_SavGolFilterDetrender():
detrender = SavGolFilterDetrender()
tmp = detrender.fit_transform(DFF_WITH_DRIFT)
assert np.all(np.isclose(np.mean(tmp), 0, atol=.1))
clone(detrender)

def test_OASISInferer():
tmp = OASISInferer().fit_transform(SavGolFilterDetrend().fit_transform(DFF))
assert np.all(np.array([np.corrcoef(true_s[n], np.array(tmp[a]))[0][1] for n,a in zip(range(3), LBL)]) > 0.6)
tmp = OASISInferer().fit_transform(MedianFilterDetrend().fit_transform(DFF))
def test_CalciumDeconvolver():
deconvolver = CalciumDeconvolver()
tmp = deconvolver.fit_transform(DFF)
assert np.all(np.array([np.corrcoef(true_s[n], np.array(tmp[a]))[0][1] for n,a in zip(range(3), LBL)]) > 0.6)
clone(deconvolver)


if __name__ == '__main__':
test_MedianFilterDetrend()
test_SavGolFilterDetrend()
test_OASISInferer()
test_params()
test_MedianFilterDetrender()
test_SavGolFilterDetrender()
test_CalciumDeconvolver()
# test_params()

0 comments on commit a8e3776

Please sign in to comment.