Skip to content

Commit

Permalink
support gradio
Browse files Browse the repository at this point in the history
  • Loading branch information
TakanoTaiga committed Oct 14, 2024
1 parent eb5179d commit e91a512
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 34 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
*.txt
*.csv
*.xls
*.png
114 changes: 80 additions & 34 deletions wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,25 @@
import numpy as np
import matplotlib.pyplot as plt
import math


# ファイルから信号データを読み込む
def load_signal(file_path):
signal = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
try:
signal.append(float(line))
except ValueError as e:
print(e, file=sys.stderr)
return signal
import pandas as pd
import gradio as gr
import tempfile


# CSVファイルから信号データを読み込む
def load_signal(file_path, column_name):
try:
df = pd.read_csv(file_path)
signal = df[column_name].values # 指定した列の信号を取得
return signal
except FileNotFoundError as e:
print(f"Error: {e}", file=sys.stderr)
return []
except KeyError as e:
print(
f"Column '{column_name}' not found in the file.({e})",
file=sys.stderr)
return []


# モルレーウェーブレット関数
Expand Down Expand Up @@ -54,37 +61,76 @@ def plot_cwt(cwt_result, time_data, fmax):
plt.xlabel("Time [min]")
plt.ylabel("Frequency [Hz]")
plt.axis([0, len(time_data) / 1000, 0, fmax - 1])
plt.xticks(np.arange(0, 1200001, step=60000))
plt.xticks(np.arange(0, len(time_data), step=60000))
plt.ticklabel_format(style='plain', axis='x')
plt.colorbar()
plt.clim(-5, 5)


if __name__ == "__main__":
# サンプリング設定
Fs = 1000 # サンプリング周波数
Ts = 1 / Fs # 時間ステップ
time_S = 1200 # 信号長さ(秒)
t_data = np.arange(0, time_S, Ts) # 時間データ

# 信号データを読み込み
signal = load_signal('fp2.txt')

# 連続ウェーブレット変換
fmax = 60
cwt_signal = continuous_wavelet_transform(Fs=Fs, data=signal, fmax=fmax)

# 信号のプロット
def wavelet_ui(filepath: str, Fs: float, fmax: float, column_name: str):
signal = load_signal(filepath, column_name)
t_data = np.arange(0, len(signal) / Fs, 1 / Fs)
signal_filename = tempfile.NamedTemporaryFile(
delete=False,
suffix='.png').name
plt.figure(dpi=200)
plt.title("Signal")
plt.plot(t_data, signal)
plt.xlim(0, 1200)
plt.xticks(np.arange(0, 1201, step=100))
plt.xlim(0, t_data[-1])
plt.xticks(np.arange(0, t_data[-1] + 1, step=100))
plt.xlabel("Time [s]")
plt.savefig("signal_fp2.png")
plt.savefig(signal_filename)

# 連続ウェーブレット変換のプロット
cwt_signal_filename = tempfile.NamedTemporaryFile(
delete=False,
suffix='.png').name
cwt_signal = continuous_wavelet_transform(Fs=Fs, data=signal, fmax=fmax)
plt.figure(dpi=200)
plot_cwt(cwt_signal, t_data, fmax)
plt.savefig("cwt_fp2.png")
plt.show()
plt.savefig(cwt_signal_filename)

return cwt_signal_filename, signal_filename


with gr.Blocks() as main_ui:
with gr.Tab("Wavelet"):
gr.Interface(
wavelet_ui,
[
gr.File(
label="CSVファイルをアップロードしてください。",
file_count="single",
file_types=["csv"]),
gr.Slider(
minimum=0,
maximum=10000,
value=1000,
label="サンプリング周波数",
step=10,
info="単位はHz。"
),
gr.Slider(
minimum=0,
maximum=200,
value=60,
label="wavelet 最大周波数",
step=10,
info="単位はHz。"
),
gr.Dropdown(
["Fp1", "Fp2", "T7", "T8", "O1", "O2"],
value="Fp2",
label="使用する信号データ",
allow_custom_value=True,
info="使用する信号データを選んでください。デフォルトはFp2です。"
),
],
[
gr.Image(type="filepath", label="Wavelet"),
gr.Image(type="filepath", label="Signal")
]
)


if __name__ == "__main__":
main_ui.queue().launch(server_name="0.0.0.0")

0 comments on commit e91a512

Please sign in to comment.