Skip to content

Commit

Permalink
Support High/Low Pass Fillter
Browse files Browse the repository at this point in the history
  • Loading branch information
TakanoTaiga committed Oct 23, 2024
1 parent bf435ab commit 1a2008f
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 230 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
__pycache__
flagged
*.DS_Store
*.mp4
5 changes: 5 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ WORKDIR /app
COPY pyproject.toml* poetry.lock* /app/
RUN poetry install
RUN rm -rf /app/pyproject.toml* /app/poetry.lock*

# For Huggingface
# COPY . /app/

# CMD [ "python3", "lab_tool_webui.py" ]
18 changes: 16 additions & 2 deletions lab_tool_webui.py → app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def update_slider_range(filepath):
column_dropdown = gr.Dropdown(["Fp1", "Fp2", "T7", "T8", "O1", "O2"], value="Fp2", label="使用する信号データ", allow_custom_value=True, info="使用する信号データを選んでください。デフォルトはFp2です。")
start_time = gr.Slider(minimum=0, maximum=60, value=0.0, step=0.5, label="Start Time (sec)")
end_time = gr.Slider(minimum=0, maximum=60, value=60.0, step=0.5, label="End Time (sec)")
filter_setting = gr.Radio(
["No Filter", "High PASS", "Low PASS"],
label="フィルター設定",
value="High PASS",
)
fp_hp = gr.Slider(minimum=0, maximum=20, value=3, step=0.1, label="通過域端周波数 [Hz]")
fs_hp = gr.Slider(minimum=0, maximum=20, value=1, step=0.1, label="阻止域端周波数 [Hz]")
gpass = gr.Slider(minimum=0, maximum=100, value=3, step=1, label="通過域端最大損失 [dB]")
gstop = gr.Slider(minimum=0, maximum=100, value=40, step=1, label="阻止域端最小損失 [dB]")

submit_button = gr.Button("計算開始")

file_input.change(
Expand All @@ -34,7 +44,11 @@ def update_slider_range(filepath):
wavelet_image = gr.Image(type="filepath", label="Wavelet")
signal_image = gr.Image(type="filepath", label="Signal")

submit_button.click(wavelet.wavelet_ui, inputs=[file_input, fs_slider, fmax_slider, column_dropdown, start_time, end_time], outputs=[wavelet_image, signal_image])
submit_button.click(wavelet.wavelet_ui, inputs=[
file_input,
fs_slider, fmax_slider, column_dropdown, start_time, end_time,
filter_setting, fp_hp, fs_hp, gpass, gstop],
outputs=[wavelet_image, signal_image])

with gr.Tab("1f Noise Search"):
with gr.Row():
Expand All @@ -50,4 +64,4 @@ def update_slider_range(filepath):


if __name__ == "__main__":
main_ui.queue().launch(server_name="0.0.0.0")
main_ui.queue().launch(server_name="0.0.0.0", server_port=7860)
21 changes: 21 additions & 0 deletions lab_tools/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import scipy.signal as signal


def apply_filter(signal_data, sample_rate, pass_freq, stop_freq, pass_gain, stop_gain, filter_type):
nyquist_freq = sample_rate / 2 # ナイキスト周波数
normalized_pass_freq = pass_freq / nyquist_freq # 通過域端周波数を正規化
normalized_stop_freq = stop_freq / nyquist_freq # 阻止域端周波数を正規化
filter_order, cutoff_freq = signal.buttord(
normalized_pass_freq, normalized_stop_freq, pass_gain, stop_gain
) # フィルタのオーダーと正規化周波数を計算
b, a = signal.butter(filter_order, cutoff_freq, filter_type) # フィルタの伝達関数を計算
filtered_signal = signal.filtfilt(b, a, signal_data) # 信号にフィルタを適用
return filtered_signal


def lowpass(signal_data, sample_rate, pass_freq, stop_freq, pass_gain, stop_gain):
return apply_filter(signal_data, sample_rate, pass_freq, stop_freq, pass_gain, stop_gain, "low")


def highpass(signal_data, sample_rate, pass_freq, stop_freq, pass_gain, stop_gain):
return apply_filter(signal_data, sample_rate, pass_freq, stop_freq, pass_gain, stop_gain, "high")
62 changes: 0 additions & 62 deletions lab_tools/highpass.py

This file was deleted.

16 changes: 14 additions & 2 deletions lab_tools/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile

from lab_tools import labutils
from lab_tools import filter


# モルレーウェーブレット関数
Expand Down Expand Up @@ -42,13 +43,24 @@ def plot_cwt(cwt_result, time_data, fmax):


# グラフ描画とCWTの処理を行う関数
def wavelet_ui(uploaded_file, Fs, fmax, column_name, start_time, end_time):
def wavelet_ui(
uploaded_file,
Fs, fmax, column_name, start_time, end_time,
filter_setting, fp_hp, fs_hp, gpass, gstop):
filepath = uploaded_file.name
signal = labutils.load_signal(filepath, column_name)

if len(signal) == 0:
return None, None

# Filter
timestamps = labutils.load_signal(filepath, "Timestamp")
dt = (timestamps[1] - timestamps[0])
samplerate = 1.0 / dt
if filter_setting == "High PASS":
signal = filter.highpass(signal, samplerate, fp_hp, fs_hp, gpass, gstop)
elif filter_setting == "Low PASS":
signal = filter.lowpass(signal, samplerate, fp_hp, fs_hp, gpass, gstop)

# 時間データを計算
t_data = np.arange(0, len(signal) / Fs, 1 / Fs)

Expand Down
71 changes: 0 additions & 71 deletions run.py

This file was deleted.

File renamed without changes.
93 changes: 0 additions & 93 deletions test.py

This file was deleted.

0 comments on commit 1a2008f

Please sign in to comment.