Skip to content

Commit

Permalink
Merge pull request #15 from arnaudon/plottting
Browse files Browse the repository at this point in the history
some plotting updates
  • Loading branch information
razimantv authored Jun 15, 2022
2 parents f9adc79 + d1ee9c3 commit 379b503
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
35 changes: 29 additions & 6 deletions netsalt/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@ def plot_stem_spectra(
else:
fig = None

# add first and last point to have spectra spanning entire plot
threshold_modes = np.insert(threshold_modes, 0, ks[0])
modal_amplitudes = np.insert(modal_amplitudes, 0, 0)
threshold_modes = np.append(threshold_modes, ks[-1])
modal_amplitudes = np.append(modal_amplitudes, 0)
markerline, _, baseline = ax.stem(
threshold_modes, modal_amplitudes, "-", linefmt="grey", markerfmt=" "
)

# colors = cycle(["C{}".format(i) for i in range(10)])
markerline.set_markerfacecolor("white")
# plt.setp(stemlines, "alpha", 0.5, "linewidth", 2)
plt.setp(baseline, "color", "grey", "linewidth", 1)
ax.set_xlabel(r"$k$")
ax.set_ylabel("Intensity (a.u.)")
Expand All @@ -114,7 +117,9 @@ def plot_stem_spectra(

ax2 = ax.twinx()
ks = np.linspace(graph.graph["params"]["k_min"], graph.graph["params"]["k_max"], 1000)
ax2.plot(ks, lorentzian(ks, graph), "r--")
gain = lorentzian(ks, graph)
ax2.plot(ks, gain, "r--")
ax2.set_ylim(0, max(gain) * 1.1)
ax2.set_xlabel(r"$\lambda$")
ax2.set_ylabel("Gain spectrum (a.u.)")

Expand Down Expand Up @@ -370,7 +375,9 @@ def plot_quantum_graph(
_savefig(graph, fig, folder, filename)


def plot_pump_traj(modes_df, with_scatter=True, with_approx=True, ax=None):
def plot_pump_traj(
modes_df, with_scatter=True, with_approx=True, ax=None, d0s_max=None, s=1, c="d0"
):
"""plot pump trajectories"""
if ax is None:
ax = plt.gca()
Expand All @@ -379,9 +386,25 @@ def plot_pump_traj(modes_df, with_scatter=True, with_approx=True, ax=None):

pumped_modes = modes_df["mode_trajectories"].to_numpy()
for pumped_mode in pumped_modes:
if with_scatter:
ax.scatter(np.real(pumped_mode), -np.imag(pumped_mode), marker="o", s=10, c="b")
ax.plot(np.real(pumped_mode), -np.imag(pumped_mode), c=next(colors))
if with_scatter:
vmax = None
if c == "d0":
c = modes_df["mode_trajectories"].columns.to_list()
if d0s_max is None:
vmax = c[max(np.argmin(abs(np.imag(pumped_modes)), axis=1)) + 1]
else:
vmax = d0s_max
ax.scatter(
np.real(pumped_mode),
-np.imag(pumped_mode),
marker="o",
s=s,
c=c,
vmax=vmax,
vmin=0,
zorder=10,
)

if "mode_trajectories_approx" in modes_df and with_approx:
pumped_modes_approx = modes_df["mode_trajectories_approx"].to_numpy()
Expand Down
2 changes: 1 addition & 1 deletion netsalt/tasks/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def run(self):
qualities = load_qualities(filename=self.input()["qualities"].path)
modes_df = load_modes(self.input()["thresholds"].path)

plot_scan(qg, qualities, modes_df, relax_upper=True)
plot_scan(qg, qualities, modes_df, relax_upper=True, with_approx=False)
plt.savefig(self.output().path, bbox_inches="tight")

def output(self):
Expand Down

0 comments on commit 379b503

Please sign in to comment.