Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support general I/O for ProgressBar output #57

Merged
merged 2 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/progress_bar.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export ProgressBar
export ProgressBar, next!

struct ProgressBar{CT,T1<:Integer, T2<:Real}
counter::CT
Expand All @@ -9,11 +9,11 @@ struct ProgressBar{CT,T1<:Integer, T2<:Real}
lock::ReentrantLock
end

function ProgressBar(max_counts; enable=true, bar_width=30)
function ProgressBar(max_counts::Int; enable::Bool=true, bar_width::Int=30)
return ProgressBar(Ref{Int64}(0), max_counts, enable, bar_width, time(), ReentrantLock())
end

function next!(p::ProgressBar)
function next!(p::ProgressBar, io::IO=stdout)

lock(p.lock)

Expand Down Expand Up @@ -43,10 +43,10 @@ function next!(p::ProgressBar)
# Construct the progress bar string
bar = "[" * repeat("=", progress) * repeat(" ", bar_width - progress) * "]"

print("\rProgress: $bar $percentage_100% --- Elapsed Time: $elapsed_time_str (ETA: $eta_str)")
flush(stdout)
print(io, "\rProgress: $bar $percentage_100% --- Elapsed Time: $elapsed_time_str (ETA: $eta_str)")
flush(io)

unlock(p.lock)

p.counter[] >= p.max_counts ? print("\n") : nothing
p.counter[] >= p.max_counts ? print(io, "\n") : nothing
end
4 changes: 3 additions & 1 deletion test/low_rank_dynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@
p0 = 0.,
atol_inv = 1e-6,
adj_condition="variational",
Δt = 0.2, )
Δt = 0.2,
progress = false
)
lrsol = lr_mesolve(H, z, B, tl, c_ops; e_ops=e_ops, f_ops=(f_entropy,), opt=opt)

# Test
Expand Down
14 changes: 14 additions & 0 deletions test/progress_bar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@testset "Progress Bar" begin
bar_width = 30
strLength = 67 + bar_width # including "\r" in the beginning of the string
prog = ProgressBar(bar_width, enable=true, bar_width=bar_width)
for p in 1:bar_width
output = sprint((t, s) -> next!(s, t), prog)

if p < bar_width
@test length(output) == strLength
else # the last output has an extra "\n" in the end
@test length(output) == strLength + 1
end
end
end
7 changes: 4 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ const GROUP = get(ENV, "GROUP", "All")

const testdir = dirname(@__FILE__)

# Put them in alphabetical order
tests = [
# Put core tests in alphabetical order
core_tests = [
"correlations_and_spectrum.jl",
"dynamical_fock_dimension_mesolve.jl",
"dynamical-shifted-fock.jl",
Expand All @@ -16,14 +16,15 @@ tests = [
"low_rank_dynamics.jl",
"negativity_and_partial_transpose.jl",
"permutation.jl",
"progress_bar.jl",
"quantum_objects.jl",
"steady_state.jl",
"time_evolution_and_partial_trace.jl",
"wigner.jl",
]

if (GROUP == "All") || (GROUP == "Core")
for test in tests
for test in core_tests
include(joinpath(testdir, test))
end
end
Loading