Skip to content

Commit 75e7a93

Browse files
fix: stop tearing down non-TTY processes on SSH session end (cherry-pick #18673) (#18677)
Cherry-picked fix: stop tearing down non-TTY processes on SSH session end (#18673) (possibly temporary) fix for #18519 Matches OpenSSH for non-tty sessions, where we don't actively terminate the process. Adds explicit tracking to the SSH server for these processes so that if we are shutting down we terminate them: this ensures that we can shut down quickly to allow shutdown scripts to run. It also ensures our tests don't leak system resources. Co-authored-by: Spike Curtis <[email protected]>
1 parent 8e8dd58 commit 75e7a93

File tree

3 files changed

+54
-18
lines changed

3 files changed

+54
-18
lines changed

agent/agentssh/agentssh.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ type Server struct {
124124
listeners map[net.Listener]struct{}
125125
conns map[net.Conn]struct{}
126126
sessions map[ssh.Session]struct{}
127+
processes map[*os.Process]struct{}
127128
closing chan struct{}
128129
// Wait for goroutines to exit, waited without
129130
// a lock on mu but protected by closing.
@@ -182,6 +183,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
182183
fs: fs,
183184
conns: make(map[net.Conn]struct{}),
184185
sessions: make(map[ssh.Session]struct{}),
186+
processes: make(map[*os.Process]struct{}),
185187
logger: logger,
186188

187189
config: config,
@@ -586,7 +588,10 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
586588
// otherwise context cancellation will not propagate properly
587589
// and SSH server close may be delayed.
588590
cmd.SysProcAttr = cmdSysProcAttr()
589-
cmd.Cancel = cmdCancel(session.Context(), logger, cmd)
591+
592+
// to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends.
593+
// c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271
594+
cmd.Cancel = nil
590595

591596
cmd.Stdout = session
592597
cmd.Stderr = session.Stderr()
@@ -609,6 +614,16 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
609614
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "no", "start_command").Add(1)
610615
return xerrors.Errorf("start: %w", err)
611616
}
617+
618+
// Since we don't cancel the process when the session stops, we still need to tear it down if we are closing. So
619+
// track it here.
620+
if !s.trackProcess(cmd.Process, true) {
621+
// must be closing
622+
err = cmdCancel(logger, cmd.Process)
623+
return xerrors.Errorf("failed to track process: %w", err)
624+
}
625+
defer s.trackProcess(cmd.Process, false)
626+
612627
sigs := make(chan ssh.Signal, 1)
613628
session.Signals(sigs)
614629
defer func() {
@@ -1052,6 +1067,27 @@ func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
10521067
return true
10531068
}
10541069

1070+
// trackCommand registers the process with the server. If the server is
1071+
// closing, the process is not registered and should be closed.
1072+
//
1073+
//nolint:revive
1074+
func (s *Server) trackProcess(p *os.Process, add bool) (ok bool) {
1075+
s.mu.Lock()
1076+
defer s.mu.Unlock()
1077+
if add {
1078+
if s.closing != nil {
1079+
// Server closed.
1080+
return false
1081+
}
1082+
s.wg.Add(1)
1083+
s.processes[p] = struct{}{}
1084+
return true
1085+
}
1086+
s.wg.Done()
1087+
delete(s.processes, p)
1088+
return true
1089+
}
1090+
10551091
// Close the server and all active connections. Server can be re-used
10561092
// after Close is done.
10571093
func (s *Server) Close() error {
@@ -1091,6 +1127,10 @@ func (s *Server) Close() error {
10911127
_ = c.Close()
10921128
}
10931129

1130+
for p := range s.processes {
1131+
_ = cmdCancel(s.logger, p)
1132+
}
1133+
10941134
s.logger.Debug(ctx, "closing SSH server")
10951135
err := s.srv.Close()
10961136

agent/agentssh/exec_other.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ package agentssh
44

55
import (
66
"context"
7-
"os/exec"
7+
"os"
88
"syscall"
99

1010
"cdr.dev/slog"
@@ -16,9 +16,7 @@ func cmdSysProcAttr() *syscall.SysProcAttr {
1616
}
1717
}
1818

19-
func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
20-
return func() error {
21-
logger.Debug(ctx, "cmdCancel: sending SIGHUP to process and children", slog.F("pid", cmd.Process.Pid))
22-
return syscall.Kill(-cmd.Process.Pid, syscall.SIGHUP)
23-
}
19+
func cmdCancel(logger slog.Logger, p *os.Process) error {
20+
logger.Debug(context.Background(), "cmdCancel: sending SIGHUP to process and children", slog.F("pid", p.Pid))
21+
return syscall.Kill(-p.Pid, syscall.SIGHUP)
2422
}

agent/agentssh/exec_windows.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package agentssh
22

33
import (
44
"context"
5-
"os/exec"
5+
"os"
66
"syscall"
77

88
"cdr.dev/slog"
@@ -12,14 +12,12 @@ func cmdSysProcAttr() *syscall.SysProcAttr {
1212
return &syscall.SysProcAttr{}
1313
}
1414

15-
func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
16-
return func() error {
17-
logger.Debug(ctx, "cmdCancel: killing process", slog.F("pid", cmd.Process.Pid))
18-
// Windows doesn't support sending signals to process groups, so we
19-
// have to kill the process directly. In the future, we may want to
20-
// implement a more sophisticated solution for process groups on
21-
// Windows, but for now, this is a simple way to ensure that the
22-
// process is terminated when the context is cancelled.
23-
return cmd.Process.Kill()
24-
}
15+
func cmdCancel(logger slog.Logger, p *os.Process) error {
16+
logger.Debug(context.Background(), "cmdCancel: killing process", slog.F("pid", p.Pid))
17+
// Windows doesn't support sending signals to process groups, so we
18+
// have to kill the process directly. In the future, we may want to
19+
// implement a more sophisticated solution for process groups on
20+
// Windows, but for now, this is a simple way to ensure that the
21+
// process is terminated when the context is cancelled.
22+
return p.Kill()
2523
}

0 commit comments

Comments
 (0)