Skip to content

Commit

Permalink
refactor transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
Justin Kiggins committed Nov 22, 2017
1 parent a8e3776 commit c16dd98
Showing 1 changed file with 41 additions and 49 deletions.
90 changes: 41 additions & 49 deletions neuroglia/calcium.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scipy.signal import medfilt, savgol_filter


class MedianFilterDetrend(BaseEstimator, TransformerMixin):
class MedianFilterDetrender(BaseEstimator, TransformerMixin):
"""
Median filter detrending
"""
Expand All @@ -15,30 +15,30 @@ def __init__(self,
self.window = window
self.peak_std_threshold = peak_std_threshold

def robust_std(self, x):
def _robust_std(self, x):
'''
Robust estimate of std
'''
MAD = np.median(np.abs(x - np.median(x)))
return 1.4826*MAD

def fit(self, X, y=None):
self.fit_params = {}
return self

def transform(self,X):
self.fit_params = {}
X_new = X.copy()
for col in X.columns:
tmp_data = X[col].values.astype(np.double)
mf = medfilt(tmp_data, self.window)
mf = np.minimum(mf, self.peak_std_threshold * self.robust_std(mf))
mf = np.minimum(mf, self.peak_std_threshold * self._robust_std(mf))
self.fit_params[col] = dict(mf=mf)
X_new[col] = tmp_data - mf

return X_new


class SavGolFilterDetrend(BaseEstimator, TransformerMixin):
class SavGolFilterDetrender(BaseEstimator, TransformerMixin):
"""
Savitzky-Golay filter detrending
"""
Expand All @@ -50,10 +50,10 @@ def __init__(self,
self.order = order

def fit(self, X, y=None):
self.fit_params = {}
return self

def transform(self,X):
self.fit_params = {}
X_new = X.copy()
for col in X.columns:
tmp_data = X[col].values.astype(np.double)
Expand All @@ -64,19 +64,16 @@ def transform(self,X):
return X_new


class EventRescale(BaseEstimator, TransformerMixin):
class EventRescaler(BaseEstimator, TransformerMixin):
"""
Savitzky-Golay filter detrending
rescale events
"""
def __init__(self,
log_transform=True,
scale=5):
def __init__(self,log_transform=True,scale=5):

self.log_transform = log_transform
self.scale = scale

def fit(self, X, y=None):
self.fit_params = {}
return self

def transform(self,X):
Expand All @@ -91,56 +88,51 @@ def transform(self,X):
return X_new


class OASISInferer(BaseEstimator, TransformerMixin):

def oasis_kwargs(penalty,indicator):

kwargs = {}

if penalty=='l0':
kwargs.update(penalty=0)
elif penalty=='l1':
kwargs.update(penalty=1)
# elif penalty=='l2':
# kwargs.update(penalty=2)

if indicator.lower()=='gcamp6f':
kwargs.update(g=(None,))
elif indicator.lower()=='gcamp6s':
kwargs.update(g=(None,None))

return kwargs


class CalciumDeconvolver(BaseEstimator, TransformerMixin):
"""docstring for OASISInferer."""
def __init__(self,
output='spikes',
g=(None,),
sn=None,
b=None,
b_nonneg=True,
optimize_g=0,
penalty=0,
**kwargs
):
super(OASISInferer, self).__init__()

self.output = output
self.g = g
self.sn = sn
self.b = b
self.b_nonneg = b_nonneg
self.optimize_g = optimize_g
def __init__(self,penalty='l0',indicator='GCaMP6f'):
self.penalty = penalty
self.kwargs = kwargs
self.indicator = indicator

def fit(self, X, y=None):
self.fit_params = {}
return self

def transform(self,X):

kwargs = oasis_kwargs(
self.penalty,
self.indicator,
)

X_new = X.copy()

self.fit_params = {}
for col in X.columns:
c, s, b, g, lam = deconvolve(
denoised, spikes, b, g, lam = deconvolve(
X[col].values.astype(np.double),
g = self.g,
sn = self.sn,
b = self.b,
b_nonneg = self.b_nonneg,
optimize_g = self.optimize_g,
penalty = self.penalty,
**self.kwargs
)
**kwargs)
self.fit_params[col] = dict(b=b,g=g,lam=lam,)

if self.output=='denoised':
X_new[col] = c
elif self.output=='spikes':
X_new[col] = np.maximum(0, s)
else:
raise NotImplementedError
X_new[col] = spikes

return X_new

Expand Down

0 comments on commit c16dd98

Please sign in to comment.