Skip to content

Commit 01bb4ee

Browse files
fix: command response processing issues (#40)
Fix command response processing issues introduced by notification handling. This includes: Random hangs due to writes to channels with no readers. Potential race when reusing client response buffer. Close of notify channel while handler was still running. Eliminated race on IsConnected check in ExecCmd. Also: Use just \n for keep-alive to reduce data sent across the wire. Wrap more errors to improve error reporting. Remove io.(Reader|Writer) wrapping as that breaks scanner io.EOF handling. --------- Co-authored-by: HalloTschuess <hallo.ich.f@gmail.com>
1 parent 6cd984d commit 01bb4ee

File tree

5 files changed

+150
-68
lines changed

5 files changed

+150
-68
lines changed

‎.github/workflows/go.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ jobs:
6363
outformat: out-format
6464
with:
6565
version: ${{ matrix.golangci }}
66-
args: "--%outformat% colored-line-number"
66+
args: "--%outformat% colored-line-number --timeout 2m"
6767
skip-pkg-cache: true
6868
skip-build-cache: true
6969

‎client.go

Lines changed: 138 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ package ts3
22

33
import (
44
"bufio"
5-
"errors"
65
"fmt"
76
"io"
87
"net"
98
"regexp"
109
"strings"
10+
"sync"
1111
"time"
1212

1313
"golang.org/x/crypto/ssh"
@@ -26,13 +26,18 @@ const (
2626
// startBufSize is the initial size of allocation for the parse buffer.
2727
startBufSize = 4096
2828

29-
// keepAliveData is the keepalive data.
30-
keepAliveData = " \n"
29+
// responseErrTimeout is the timeout used for sending response errors.
30+
responseErrTimeout = time.Millisecond * 100
3131
)
3232

3333
var (
34+
// respTrailerRe is the regexp which matches a server response to a command.
3435
respTrailerRe = regexp.MustCompile(`^error id=(\d+) msg=([^ ]+)(.*)`)
3536

37+
// keepAliveData is data which will be ignored by the server used to ensure
38+
// the connection is kept alive.
39+
keepAliveData = []byte(" \n")
40+
3641
// DefaultTimeout is the default read / write / dial timeout for Clients.
3742
DefaultTimeout = 10 * time.Second
3843

@@ -52,6 +57,11 @@ type Connection interface {
5257
Connect(addr string, timeout time.Duration) error
5358
}
5459

60+
type response struct {
61+
err error
62+
lines []string
63+
}
64+
5565
// Client is a TeamSpeak 3 ServerQuery client.
5666
type Client struct {
5767
conn Connection
@@ -62,11 +72,13 @@ type Client struct {
6272
maxBufSize int
6373
notifyBufSize int
6474
work chan string
65-
err chan error
75+
response chan response
6676
notify chan Notification
67-
disconnect chan struct{}
68-
res []string
77+
closing chan struct{} // closing is closed to indicate we're closing our connection.
78+
done chan struct{} // done is closed once we're seen a fatal error.
79+
doneOnce sync.Once
6980
connectHeader string
81+
wg sync.WaitGroup
7082

7183
Server *ServerMethods
7284
}
@@ -151,8 +163,9 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) {
151163
maxBufSize: MaxParseTokenSize,
152164
notifyBufSize: DefaultNotifyBufSize,
153165
work: make(chan string),
154-
err: make(chan error),
155-
disconnect: make(chan struct{}),
166+
response: make(chan response),
167+
closing: make(chan struct{}),
168+
done: make(chan struct{}),
156169
connectHeader: DefaultConnectHeader,
157170
}
158171
for _, f := range options {
@@ -183,7 +196,7 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) {
183196

184197
// Read the connection header
185198
if !c.scanner.Scan() {
186-
return nil, c.scanErr()
199+
return nil, fmt.Errorf("client: header: %w", c.scanErr())
187200
}
188201

189202
if l := c.scanner.Text(); l != c.connectHeader {
@@ -192,30 +205,67 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) {
192205

193206
// Slurp the banner
194207
if !c.scanner.Scan() {
195-
return nil, c.scanErr()
208+
return nil, fmt.Errorf("client: banner: %w", c.scanErr())
196209
}
197210

198211
if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
199212
return nil, fmt.Errorf("client: set read deadline: %w", err)
200213
}
201214

202215
// Start handlers
216+
c.wg.Add(2)
203217
go c.messageHandler()
204218
go c.workHandler()
205219

206220
return c, nil
207221
}
208222

223+
// fatalError returns false if err is nil otherwise it ensures
224+
// that done is closed and returns true.
225+
func (c *Client) fatalError(err error) bool {
226+
if err == nil {
227+
return false
228+
}
229+
230+
c.closeDone()
231+
return true
232+
}
233+
234+
// closeDone safely closes c.done.
235+
func (c *Client) closeDone() {
236+
c.doneOnce.Do(func() {
237+
close(c.done)
238+
})
239+
}
240+
209241
// messageHandler scans incoming lines and handles them accordingly.
242+
// - Notifications are sent to c.notify.
243+
// - ExecCmd responses are sent to c.response.
244+
// If a fatal error occurs it stops processing and exits.
210245
func (c *Client) messageHandler() {
246+
defer func() {
247+
close(c.notify)
248+
c.wg.Done()
249+
}()
250+
251+
buf := make([]string, 0, 10)
211252
for {
212253
if c.scanner.Scan() {
213254
line := c.scanner.Text()
214-
//nolint: gocritic
215255
if line == "error id=0 msg=ok" {
216-
c.err <- nil
256+
var resp response
257+
// Avoid creating a new buf if there was no data in the response.
258+
if len(buf) > 0 {
259+
resp.lines = buf
260+
buf = make([]string, 0, 10)
261+
}
262+
c.response <- resp
217263
} else if matches := respTrailerRe.FindStringSubmatch(line); len(matches) == 4 {
218-
c.err <- NewError(matches)
264+
c.response <- response{err: NewError(matches)}
265+
// Avoid creating a new buf if there was no data in the response.
266+
if len(buf) > 0 {
267+
buf = make([]string, 0, 10)
268+
}
219269
} else if strings.Index(line, "notify") == 0 {
220270
if n, err := decodeNotification(line); err == nil {
221271
// non-blocking write
@@ -225,40 +275,70 @@ func (c *Client) messageHandler() {
225275
}
226276
}
227277
} else {
228-
c.res = append(c.res, line)
278+
// Partial response.
279+
buf = append(buf, line)
229280
}
230281
} else {
231-
err := c.scanErr()
232-
c.err <- err
233-
if errors.Is(err, io.ErrUnexpectedEOF) {
234-
close(c.disconnect)
235-
return
282+
if err := c.scanErr(); c.fatalError(err) {
283+
c.responseErr(err)
284+
} else {
285+
// Ensure that done is closed as scanner has seen an io.EOF.
286+
c.closeDone()
236287
}
288+
return
237289
}
238290
}
239291
}
240292

293+
// responseErr sends err to c.response with a timeout to ensure it
294+
// doesn't block forever when multiple errors occur during the
295+
// processing of a single ExecCmd call.
296+
func (c *Client) responseErr(err error) {
297+
t := time.NewTimer(responseErrTimeout)
298+
defer t.Stop()
299+
300+
select {
301+
case c.response <- response{err: err}:
302+
case <-t.C:
303+
}
304+
}
305+
241306
// workHandler handles commands and keepAlive messages.
242307
func (c *Client) workHandler() {
308+
defer c.wg.Done()
309+
243310
for {
244311
select {
245312
case w := <-c.work:
246-
c.process(w)
313+
if err := c.write([]byte(w)); c.fatalError(err) {
314+
// Command send failed, inform the caller.
315+
c.responseErr(err)
316+
return
317+
}
247318
case <-time.After(c.keepAlive):
248-
c.process(keepAliveData)
249-
case <-c.disconnect:
319+
// Send a keep alive to prevent the connection from timing out.
320+
if err := c.write(keepAliveData); c.fatalError(err) {
321+
// We don't send to c.response as no ExecCmd is expecting a
322+
// response and the next caller will get an error.
323+
return
324+
}
325+
case <-c.done:
250326
return
251327
}
252328
}
253329
}
254330

255-
func (c *Client) process(data string) {
331+
// write writes data to the clients connection with the configured timeout
332+
// returning any error.
333+
func (c *Client) write(data []byte) error {
256334
if err := c.conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil {
257-
c.err <- err
335+
return fmt.Errorf("set deadline: %w", err)
258336
}
259-
if _, err := c.conn.Write([]byte(data)); err != nil {
260-
c.err <- err
337+
if _, err := c.conn.Write(data); err != nil {
338+
return fmt.Errorf("write: %w", err)
261339
}
340+
341+
return nil
262342
}
263343

264344
// Exec executes cmd on the server and returns the response.
@@ -268,37 +348,36 @@ func (c *Client) Exec(cmd string) ([]string, error) {
268348

269349
// ExecCmd executes cmd on the server and returns the response.
270350
func (c *Client) ExecCmd(cmd *Cmd) ([]string, error) {
271-
if !c.IsConnected() {
351+
select {
352+
case c.work <- cmd.String():
353+
case <-c.done:
272354
return nil, ErrNotConnected
273355
}
274356

275-
c.work <- cmd.String()
276-
357+
var resp response
277358
select {
278-
case err := <-c.err:
279-
if err != nil {
280-
return nil, err
359+
case resp = <-c.response:
360+
if resp.err != nil {
361+
return nil, resp.err
281362
}
282363
case <-time.After(c.timeout):
283364
return nil, ErrTimeout
284365
}
285366

286-
res := c.res
287-
c.res = nil
288-
289367
if cmd.response != nil {
290-
if err := DecodeResponse(res, cmd.response); err != nil {
368+
if err := DecodeResponse(resp.lines, cmd.response); err != nil {
291369
return nil, err
292370
}
293371
}
294372

295-
return res, nil
373+
return resp.lines, nil
296374
}
297375

298-
// IsConnected returns whether the client is connected.
376+
// IsConnected returns true if the client is connected,
377+
// false otherwise.
299378
func (c *Client) IsConnected() bool {
300379
select {
301-
case <-c.disconnect:
380+
case <-c.done:
302381
return false
303382
default:
304383
return true
@@ -307,8 +386,10 @@ func (c *Client) IsConnected() bool {
307386

308387
// Close closes the connection to the server.
309388
func (c *Client) Close() error {
310-
defer close(c.notify)
389+
defer c.wg.Wait()
311390

391+
// Signal we're expecting EOF.
392+
close(c.closing)
312393
_, err := c.Exec("quit")
313394
err2 := c.conn.Close()
314395

@@ -321,11 +402,23 @@ func (c *Client) Close() error {
321402
return nil
322403
}
323404

324-
// scanError returns the error from the scanner if non-nil,
325-
// `io.ErrUnexpectedEOF` otherwise.
405+
// scanError returns nil if c is closing else if the scanner returns a
406+
// non-nil error it is returned, otherwise returns `io.ErrUnexpectedEOF`.
407+
// Callers must have seen c.scanner.Scan() return false.
326408
func (c *Client) scanErr() error {
327-
if err := c.scanner.Err(); err != nil {
328-
return fmt.Errorf("client: scan: %w", err)
409+
select {
410+
case <-c.closing:
411+
// We know we're closing the connection so ignore any errors
412+
// an return nil. This prevents spurious errors being returned
413+
// to the caller.
414+
return nil
415+
default:
416+
if err := c.scanner.Err(); err != nil {
417+
return fmt.Errorf("scan: %w", err)
418+
}
419+
420+
// As caller has seen c.scanner.Scan() return false
421+
// this must have been triggered by an unexpected EOF.
422+
return io.ErrUnexpectedEOF
329423
}
330-
return io.ErrUnexpectedEOF
331424
}

‎connection.go

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,13 @@ func (c *sshConnection) Connect(addr string, timeout time.Duration) error {
8989

9090
// Read implements io.Reader.
9191
func (c *sshConnection) Read(p []byte) (n int, err error) {
92-
if n, err = c.channel.Read(p); err != nil {
93-
return n, fmt.Errorf("ssh connection: read: %w", err)
94-
}
95-
return n, nil
92+
// Don't wrap as it needs to return raw EOF as per https://pkg.go.dev/io#Reader
93+
return c.channel.Read(p) //nolint: wrapcheck
9694
}
9795

9896
// Write implements io.Writer.
9997
func (c *sshConnection) Write(p []byte) (n int, err error) {
100-
if n, err = c.channel.Write(p); err != nil {
101-
return n, fmt.Errorf("ssh connection: write: %w", err)
102-
}
103-
return n, nil
98+
return c.channel.Write(p) //nolint: wrapcheck
10499
}
105100

106101
// Close implements io.Closer.

‎mockserver_test.go

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,7 @@ func (c *sshServerShell) Read(b []byte) (int, error) {
393393
return 0, err
394394
}
395395

396-
n, err := ch.Read(b)
397-
if err != nil {
398-
return n, fmt.Errorf("mock ssh shell: channel read: %w", err)
399-
}
400-
return n, nil
396+
return ch.Read(b) //nolint: wrapcheck
401397
}
402398

403399
// Write writes to the ssh channel.
@@ -407,11 +403,7 @@ func (c *sshServerShell) Write(b []byte) (int, error) {
407403
return 0, err
408404
}
409405

410-
n, err := ch.Write(b)
411-
if err != nil {
412-
return n, fmt.Errorf("mock ssh shell: channel write: %w", err)
413-
}
414-
return n, nil
406+
return ch.Write(b) //nolint: wrapcheck
415407
}
416408

417409
// Close closes the ssh channel and connection.
@@ -420,8 +412,6 @@ func (c *sshServerShell) Close() error {
420412
c.closed = true
421413
c.mtx.Unlock()
422414
c.cond.Broadcast()
423-
if err := c.Conn.Close(); err != nil {
424-
return fmt.Errorf("mock ssh shell: close: %w", err)
425-
}
426-
return nil
415+
416+
return c.Conn.Close() //nolint: wrapcheck
427417
}

0 commit comments

Comments
 (0)