Skip to content

Commit 7747924

Browse files
fix: stop tearing down non-TTY processes on SSH session end (cherry-pick #18673) (#18676)
Cherry-picked fix: stop tearing down non-TTY processes on SSH session end (#18673) (possibly temporary) fix for #18519 Matches OpenSSH for non-tty sessions, where we don't actively terminate the process. Adds explicit tracking to the SSH server for these processes so that if we are shutting down we terminate them: this ensures that we can shut down quickly to allow shutdown scripts to run. It also ensures our tests don't leak system resources. Co-authored-by: Spike Curtis <[email protected]>
1 parent 4a61bbe commit 7747924

File tree

3 files changed

+54
-18
lines changed

3 files changed

+54
-18
lines changed

agent/agentssh/agentssh.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ type Server struct {
125125
listeners map[net.Listener]struct{}
126126
conns map[net.Conn]struct{}
127127
sessions map[ssh.Session]struct{}
128+
processes map[*os.Process]struct{}
128129
closing chan struct{}
129130
// Wait for goroutines to exit, waited without
130131
// a lock on mu but protected by closing.
@@ -183,6 +184,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
183184
fs: fs,
184185
conns: make(map[net.Conn]struct{}),
185186
sessions: make(map[ssh.Session]struct{}),
187+
processes: make(map[*os.Process]struct{}),
186188
logger: logger,
187189

188190
config: config,
@@ -587,7 +589,10 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
587589
// otherwise context cancellation will not propagate properly
588590
// and SSH server close may be delayed.
589591
cmd.SysProcAttr = cmdSysProcAttr()
590-
cmd.Cancel = cmdCancel(session.Context(), logger, cmd)
592+
593+
// to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends.
594+
// c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271
595+
cmd.Cancel = nil
591596

592597
cmd.Stdout = session
593598
cmd.Stderr = session.Stderr()
@@ -610,6 +615,16 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
610615
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "no", "start_command").Add(1)
611616
return xerrors.Errorf("start: %w", err)
612617
}
618+
619+
// Since we don't cancel the process when the session stops, we still need to tear it down if we are closing. So
620+
// track it here.
621+
if !s.trackProcess(cmd.Process, true) {
622+
// must be closing
623+
err = cmdCancel(logger, cmd.Process)
624+
return xerrors.Errorf("failed to track process: %w", err)
625+
}
626+
defer s.trackProcess(cmd.Process, false)
627+
613628
sigs := make(chan ssh.Signal, 1)
614629
session.Signals(sigs)
615630
defer func() {
@@ -1070,6 +1085,27 @@ func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
10701085
return true
10711086
}
10721087

1088+
// trackCommand registers the process with the server. If the server is
1089+
// closing, the process is not registered and should be closed.
1090+
//
1091+
//nolint:revive
1092+
func (s *Server) trackProcess(p *os.Process, add bool) (ok bool) {
1093+
s.mu.Lock()
1094+
defer s.mu.Unlock()
1095+
if add {
1096+
if s.closing != nil {
1097+
// Server closed.
1098+
return false
1099+
}
1100+
s.wg.Add(1)
1101+
s.processes[p] = struct{}{}
1102+
return true
1103+
}
1104+
s.wg.Done()
1105+
delete(s.processes, p)
1106+
return true
1107+
}
1108+
10731109
// Close the server and all active connections. Server can be re-used
10741110
// after Close is done.
10751111
func (s *Server) Close() error {
@@ -1109,6 +1145,10 @@ func (s *Server) Close() error {
11091145
_ = c.Close()
11101146
}
11111147

1148+
for p := range s.processes {
1149+
_ = cmdCancel(s.logger, p)
1150+
}
1151+
11121152
s.logger.Debug(ctx, "closing SSH server")
11131153
err := s.srv.Close()
11141154

agent/agentssh/exec_other.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ package agentssh
44

55
import (
66
"context"
7-
"os/exec"
7+
"os"
88
"syscall"
99

1010
"cdr.dev/slog"
@@ -16,9 +16,7 @@ func cmdSysProcAttr() *syscall.SysProcAttr {
1616
}
1717
}
1818

19-
func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
20-
return func() error {
21-
logger.Debug(ctx, "cmdCancel: sending SIGHUP to process and children", slog.F("pid", cmd.Process.Pid))
22-
return syscall.Kill(-cmd.Process.Pid, syscall.SIGHUP)
23-
}
19+
func cmdCancel(logger slog.Logger, p *os.Process) error {
20+
logger.Debug(context.Background(), "cmdCancel: sending SIGHUP to process and children", slog.F("pid", p.Pid))
21+
return syscall.Kill(-p.Pid, syscall.SIGHUP)
2422
}

agent/agentssh/exec_windows.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package agentssh
22

33
import (
44
"context"
5-
"os/exec"
5+
"os"
66
"syscall"
77

88
"cdr.dev/slog"
@@ -12,14 +12,12 @@ func cmdSysProcAttr() *syscall.SysProcAttr {
1212
return &syscall.SysProcAttr{}
1313
}
1414

15-
func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
16-
return func() error {
17-
logger.Debug(ctx, "cmdCancel: killing process", slog.F("pid", cmd.Process.Pid))
18-
// Windows doesn't support sending signals to process groups, so we
19-
// have to kill the process directly. In the future, we may want to
20-
// implement a more sophisticated solution for process groups on
21-
// Windows, but for now, this is a simple way to ensure that the
22-
// process is terminated when the context is cancelled.
23-
return cmd.Process.Kill()
24-
}
15+
func cmdCancel(logger slog.Logger, p *os.Process) error {
16+
logger.Debug(context.Background(), "cmdCancel: killing process", slog.F("pid", p.Pid))
17+
// Windows doesn't support sending signals to process groups, so we
18+
// have to kill the process directly. In the future, we may want to
19+
// implement a more sophisticated solution for process groups on
20+
// Windows, but for now, this is a simple way to ensure that the
21+
// process is terminated when the context is cancelled.
22+
return p.Kill()
2523
}

0 commit comments

Comments
 (0)