Skip to content

Ensure connection is closed at all error points #193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
defer c.readMu.unlock()

if !c.msgReader.fin {
return 0, nil, errors.New("previous message not read to completion")
err = errors.New("previous message not read to completion")
c.close(fmt.Errorf("failed to get reader: %w", err))
return 0, nil, err
}

h, err := c.readLoop(ctx)
Expand Down Expand Up @@ -361,21 +363,9 @@ func (mr *msgReader) setFrame(h header) {
}

func (mr *msgReader) Read(p []byte) (n int, err error) {
defer func() {
if errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
err = io.EOF
}
if errors.Is(err, io.EOF) {
err = io.EOF
mr.putFlateReader()
return
}
errd.Wrap(&err, "failed to read")
}()

err = mr.c.readMu.lock(mr.ctx)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to read: %w", err)
}
defer mr.c.readMu.unlock()

Expand All @@ -384,6 +374,14 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
p = p[:n]
mr.dict.write(p)
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
mr.putFlateReader()
return n, io.EOF
}
if err != nil {
err = fmt.Errorf("failed to read: %w", err)
mr.c.close(err)
}
return n, err
}

Expand Down
42 changes: 30 additions & 12 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"errors"
"fmt"
"io"
"sync"
"time"

"github.com/klauspost/compress/flate"
Expand Down Expand Up @@ -71,7 +70,7 @@ type msgWriterState struct {
c *Conn

mu *mu
writeMu sync.Mutex
writeMu *mu

ctx context.Context
opcode opcode
Expand All @@ -83,8 +82,9 @@ type msgWriterState struct {

func newMsgWriterState(c *Conn) *msgWriterState {
mw := &msgWriterState{
c: c,
mu: newMu(c),
c: c,
mu: newMu(c),
writeMu: newMu(c),
}
return mw
}
Expand Down Expand Up @@ -155,10 +155,18 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {

// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
defer errd.Wrap(&err, "failed to write")
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
defer mw.writeMu.unlock()

mw.writeMu.Lock()
defer mw.writeMu.Unlock()
defer func() {
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
mw.c.close(err)
}
}()

if mw.c.flate() {
// Only enables flate if the length crosses the
Expand Down Expand Up @@ -193,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
func (mw *msgWriterState) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")

mw.writeMu.Lock()
defer mw.writeMu.Unlock()
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return err
}
defer mw.writeMu.unlock()

_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
Expand All @@ -214,7 +225,7 @@ func (mw *msgWriterState) close() {
putBufioWriter(mw.c.bw)
}

mw.writeMu.Lock()
mw.writeMu.forceLock()
mw.dict.close()
}

Expand All @@ -230,8 +241,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
}

// frame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) {
err := c.writeFrameMu.lock(ctx)
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
Expand All @@ -243,6 +254,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
case c.writeTimeout <- ctx:
}

defer func() {
if err != nil {
err = fmt.Errorf("failed to write frame: %w", err)
c.close(err)
}
}()

c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
Expand Down