Skip to content

Commit

Permalink
fix(pkg/process): panic on wait before process initialization (#72)
Browse files Browse the repository at this point in the history
* fix(pkg/process): panic on wait before process initialization

Signed-off-by: Gyuho Lee <[email protected]>
  • Loading branch information
gyuho authored Sep 18, 2024
1 parent 6879cc3 commit 3a2d60a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 38 deletions.
63 changes: 47 additions & 16 deletions pkg/process/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ func (p *process) startCommand() error {
p.cmd.Stdout = p.outputFile
p.cmd.Stderr = p.outputFile

p.stdoutReader = p.outputFile
p.stderrReader = p.outputFile

default:
var err error
p.stdoutReader, err = p.cmd.StdoutPipe()
Expand Down Expand Up @@ -188,12 +191,24 @@ func (p *process) Wait() <-chan error {
}

func (p *process) watchCmd() {
if p.cmd == nil {
return
}
defer func() {
close(p.errc)
}()

restartCount := 0
for {
if p.cmd.Process == nil { // Wait cannot be called if the process is not started yet
select {
case <-p.ctx.Done():
return
case <-time.After(time.Second):
}
continue
}

errc := make(chan error)
go func() {
errc <- p.cmd.Wait()
Expand Down Expand Up @@ -265,22 +280,22 @@ func (p *process) Abort(ctx context.Context) error {

p.cancel()

finished := false
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
if err.Error() == "os: process already finished" {
finished = true
} else {
log.Logger.Warnw("failed to send SIGTERM to process", "error", err)
if p.cmd.Process != nil {
finished := false
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
if err.Error() == "os: process already finished" {
finished = true
} else {
log.Logger.Warnw("failed to send SIGTERM to process", "error", err)
}
}
}

if !finished {
select {
case <-p.ctx.Done():
return ctx.Err()
case <-time.After(3 * time.Second):
if err := p.cmd.Process.Kill(); err != nil {
log.Logger.Warnw("failed to send SIGKILL to process", "error", err)
if !finished {
select {
case <-p.ctx.Done():
case <-time.After(3 * time.Second):
if err := p.cmd.Process.Kill(); err != nil {
log.Logger.Warnw("failed to send SIGKILL to process", "error", err)
}
}
}
}
Expand All @@ -291,7 +306,23 @@ func (p *process) Abort(ctx context.Context) error {
return os.RemoveAll(p.runBashFile.Name())
}

p.cmd = nil
if p.stdoutReader != nil {
_ = p.stdoutReader.Close()
p.stdoutReader = nil
}
if p.stderrReader != nil {
_ = p.stderrReader.Close()
p.stderrReader = nil
}

if p.cmd.Cancel != nil { // if created with CommandContext
_ = p.cmd.Cancel()
}

// do not set p.cmd to nil
// as Wait is still waiting for the process to exit
// p.cmd = nil

return nil
}

Expand Down
41 changes: 19 additions & 22 deletions pkg/process/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@ import (
"bufio"
"context"
"fmt"
"io"
"os"
"strings"
"testing"
"time"
)

func TestProcess(t *testing.T) {
t.Parallel()

p, err := New(
[][]string{
{"echo", "hello"},
},
WithOutputFile(os.Stderr),
)
if err != nil {
t.Fatal(err)
Expand All @@ -31,6 +29,22 @@ func TestProcess(t *testing.T) {
}
t.Logf("pid: %d", p.PID())

b, err := io.ReadAll(p.StderrReader())
if err != nil {
if !strings.Contains(err.Error(), "file already closed") {
t.Fatal(err)
}
}
t.Logf("stderr: %q", string(b))

b, err = io.ReadAll(p.StdoutReader())
if err != nil {
if !strings.Contains(err.Error(), "file already closed") {
t.Fatal(err)
}
}
t.Logf("stdout: %q", string(b))

select {
case err := <-p.Wait():
if err != nil {
Expand All @@ -43,20 +57,17 @@ func TestProcess(t *testing.T) {
if err := p.Abort(ctx); err != nil {
t.Fatal(err)
}
if err := p.Abort(ctx); err == nil {
t.Fatal("exected error")
if err := p.Abort(ctx); err != nil {
t.Fatal(err)
}
}

func TestProcessWithBash(t *testing.T) {
t.Parallel()

p, err := New(
[][]string{
{"echo", "hello"},
{"echo hello && echo 111 | grep 1"},
},
WithOutputFile(os.Stderr),
WithRunAsBashScript(),
)
if err != nil {
Expand Down Expand Up @@ -86,8 +97,6 @@ func TestProcessWithBash(t *testing.T) {
}

func TestProcessWithTempFile(t *testing.T) {
t.Parallel()

// create a temporary file
tmpFile, err := os.CreateTemp("", "process-test-*.txt")
if err != nil {
Expand Down Expand Up @@ -140,8 +149,6 @@ func TestProcessWithTempFile(t *testing.T) {
}

func TestProcessWithStdoutReader(t *testing.T) {
t.Parallel()

p, err := New(
[][]string{
{"echo hello && sleep 1000"},
Expand Down Expand Up @@ -187,8 +194,6 @@ func TestProcessWithStdoutReader(t *testing.T) {
}

func TestProcessWithStdoutReaderUntilEOF(t *testing.T) {
t.Parallel()

p, err := New(
[][]string{
{"echo hello 1 && sleep 1"},
Expand Down Expand Up @@ -238,14 +243,11 @@ func TestProcessWithStdoutReaderUntilEOF(t *testing.T) {
}

func TestProcessWithRestarts(t *testing.T) {
t.Parallel()

p, err := New(
[][]string{
{"echo hello"},
{"echo 111 && exit 1"},
},
WithOutputFile(os.Stderr),
WithRunAsBashScript(),
WithRestartConfig(RestartConfig{
OnError: true,
Expand Down Expand Up @@ -288,13 +290,10 @@ func TestProcessWithRestarts(t *testing.T) {
}

func TestProcessSleep(t *testing.T) {
t.Parallel()

p, err := New(
[][]string{
{"sleep", "9999"},
},
WithOutputFile(os.Stderr),
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -324,8 +323,6 @@ func TestProcessSleep(t *testing.T) {
}

func TestProcessStream(t *testing.T) {
t.Parallel()

cmds := make([][]string, 0, 100)
for i := 0; i < 100; i++ {
cmds = append(cmds, []string{fmt.Sprintf("echo hello %d && sleep 1", i)})
Expand Down

0 comments on commit 3a2d60a

Please sign in to comment.