diff --git a/cookie.go b/cookie.go index d5cdb29..9caa0a9 100644 --- a/cookie.go +++ b/cookie.go @@ -98,6 +98,8 @@ func (c Cookie) replyChecked() ([]byte, error) { return reply, nil case err := <-c.errorChan: return nil, err + case <-c.conn.done: + return nil, errors.New("X connection was closed") } } @@ -121,6 +123,8 @@ func (c Cookie) replyUnchecked() ([]byte, error) { return reply, nil case <-c.pingChan: return nil, nil + case <-c.conn.done: + return nil, errors.New("X connection was closed") } } @@ -161,5 +165,7 @@ func (c Cookie) Check() error { return err case <-c.pingChan: return nil + case <-c.conn.done: + return errors.New("X connection was closed") } } diff --git a/xgb.go b/xgb.go index 3d2c61f..068d160 100644 --- a/xgb.go +++ b/xgb.go @@ -60,7 +60,8 @@ type Conn struct { xidChan chan xid seqChan chan uint16 reqChan chan *request - closing chan chan struct{} + done chan struct{} + wg sync.WaitGroup // ExtLock is a lock used whenever new extensions are initialized. // It should not be used. It is exported for use in the extension @@ -101,7 +102,7 @@ func NewConnDisplay(display string) (*Conn, error) { return postNewConn(conn) } -// NewConnDisplay is just like NewConn, but allows a specific net.Conn +// NewConnNet is just like NewConn, but allows a specific net.Conn // to be used. func NewConnNet(netConn net.Conn) (*Conn, error) { conn := &Conn{} @@ -125,8 +126,9 @@ func postNewConn(conn *Conn) (*Conn, error) { conn.seqChan = make(chan uint16, seqBuffer) conn.reqChan = make(chan *request, reqBuffer) conn.eventChan = make(chan eventOrError, eventBuffer) - conn.closing = make(chan chan struct{}, 1) + conn.done = make(chan struct{}) + conn.wg.Add(4) go conn.generateXIds() go conn.generateSeqIds() go conn.sendRequests() @@ -137,7 +139,19 @@ func postNewConn(conn *Conn) (*Conn, error) { // Close gracefully closes the connection to the X server. func (c *Conn) Close() { - close(c.reqChan) + c.broadcastDone() + c.wg.Wait() + c.conn.Close() + c.conn = nil +} + +func (c *Conn) broadcastDone() { + select { + case <-c.done: + return + default: + close(c.done) + } } // Event is an interface that can contain any of the events returned by the @@ -217,8 +231,9 @@ type xid struct { // This needs to be updated to use the XC Misc extension once we run out of // new ids. // Thanks to libxcb/src/xcb_xid.c. This code is greatly inspired by it. -func (conn *Conn) generateXIds() { - defer close(conn.xidChan) +func (c *Conn) generateXIds() { + defer c.wg.Done() + defer close(c.xidChan) // This requires some explanation. From the horse's mouth: // "The resource-id-mask contains a single contiguous set of bits (at least @@ -236,23 +251,22 @@ func (conn *Conn) generateXIds() { // 00111000 & 11001000 = 00001000. // And we use that value to increment the last resource id to get a new one. // (And then, of course, we OR it with resource-id-base.) - inc := conn.setupResourceIdMask & -conn.setupResourceIdMask - max := conn.setupResourceIdMask + inc := c.setupResourceIdMask & -c.setupResourceIdMask + max := c.setupResourceIdMask last := uint32(0) for { - // TODO: Use the XC Misc extension to look for released ids. + var id xid if last > 0 && last >= max-inc+1 { - conn.xidChan <- xid{ - id: 0, - err: errors.New("There are no more available resource" + - "identifiers."), - } + // TODO: Use the XC Misc extension to look for released ids. + id.err = errors.New("there are no more available resource identifiers") + } else { + last += inc + id.id = last | c.setupResourceIdBase } - - last += inc - conn.xidChan <- xid{ - id: last | conn.setupResourceIdBase, - err: nil, + select { + case <-c.done: + return + case c.xidChan <- id: } } } @@ -271,15 +285,20 @@ func (c *Conn) newSequenceId() uint16 { // N.B. As long as the cookie buffer is less than 2^16, there are no limitations // on the number (or kind) of requests made in sequence. func (c *Conn) generateSeqIds() { + defer c.wg.Done() defer close(c.seqChan) seqid := uint16(1) for { - c.seqChan <- seqid - if seqid == uint16((1<<16)-1) { - seqid = 0 - } else { - seqid++ + select { + case <-c.done: + return + case c.seqChan <- seqid: + if seqid == uint16((1<<16)-1) { + seqid = 0 + } else { + seqid++ + } } } } @@ -315,8 +334,20 @@ type request struct { // edits the generated code for the request you want to issue. func (c *Conn) NewRequest(buf []byte, cookie *Cookie) { seq := make(chan struct{}) - c.reqChan <- &request{buf: buf, cookie: cookie, seq: seq} - <-seq + + select { + case c.reqChan <- &request{buf: buf, cookie: cookie, seq: seq}: + case <-c.done: + // If connection was broken, all goroutines, including `sendRequests`, will be closed. + // This prevents NewRequest from blocking forever in c.reqChan <- &request if reqChan is full. + // We can't close c.reqChan since NewRequest will panic when sending to it. + return + } + + select { + case <-seq: + case <-c.done: + } } // sendRequests is run as a single goroutine that takes requests and writes @@ -324,28 +355,32 @@ func (c *Conn) NewRequest(buf []byte, cookie *Cookie) { // It is meant to be run as its own goroutine. func (c *Conn) sendRequests() { defer close(c.cookieChan) + defer func() { + c.noop() // Flush the response reading goroutine, ignore error. + c.wg.Done() + }() - for req := range c.reqChan { - // ho there! if the cookie channel is nearly full, force a round - // trip to clear out the cookie buffer. - // Note that we circumvent the request channel, because we're *in* - // the request channel. - if len(c.cookieChan) == cookieBuffer-1 { - if err := c.noop(); err != nil { - // Shut everything down. - break + for { + select { + case req := <-c.reqChan: + // ho there! if the cookie channel is nearly full, force a round + // trip to clear out the cookie buffer. + // Note that we circumvent the request channel, because we're *in* + // the request channel. + if len(c.cookieChan) == cookieBuffer-1 { + if err := c.noop(); err != nil { + // Shut everything down. + return + } } + req.cookie.Sequence = c.newSequenceId() + c.cookieChan <- req.cookie + c.writeBuffer(req.buf) + close(req.seq) + case <-c.done: + return } - req.cookie.Sequence = c.newSequenceId() - c.cookieChan <- req.cookie - c.writeBuffer(req.buf) - close(req.seq) } - response := make(chan struct{}) - c.closing <- response - c.noop() // Flush the response reading goroutine, ignore error. - <-response - c.conn.Close() } // noop circumvents the usual request sending goroutines and forces a round @@ -366,9 +401,8 @@ func (c *Conn) writeBuffer(buf []byte) error { if _, err := c.conn.Write(buf); err != nil { Logger.Printf("A write error is unrecoverable: %s", err) return err - } else { - return nil } + return nil } // readResponses is a goroutine that reads events, errors and @@ -381,6 +415,7 @@ func (c *Conn) writeBuffer(buf []byte) error { // channel. (It is an error if no such cookie exists in this case.) // Finally, cookies that came "before" this reply are always cleaned up. func (c *Conn) readResponses() { + defer c.wg.Done() defer close(c.eventChan) var ( @@ -391,18 +426,16 @@ func (c *Conn) readResponses() { for { select { - case respond := <-c.closing: - respond <- struct{}{} + case <-c.done: return default: } - buf := make([]byte, 32) err, seq = nil, 0 if _, err := io.ReadFull(c.conn, buf); err != nil { Logger.Printf("A read error is unrecoverable: %s", err) c.eventChan <- err - c.Close() + c.broadcastDone() continue } switch buf[0] { @@ -432,7 +465,7 @@ func (c *Conn) readResponses() { if _, err := io.ReadFull(c.conn, biggerBuf[32:]); err != nil { Logger.Printf("A read error is unrecoverable: %s", err) c.eventChan <- err - c.Close() + c.broadcastDone() continue } replyBytes = biggerBuf