480 lines
9.1 KiB
Go
480 lines
9.1 KiB
Go
// Copyright 2020 lesismal. All rights reserved.
|
|
// Use of this source code is governed by an MIT-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
//go:build linux || darwin || netbsd || freebsd || openbsd || dragonfly
|
|
// +build linux darwin netbsd freebsd openbsd dragonfly
|
|
|
|
package nbio
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/lesismal/nbio/mempool"
|
|
)
|
|
|
|
// Conn implements net.Conn.
|
|
type Conn struct {
|
|
mux sync.Mutex
|
|
|
|
g *Gopher
|
|
|
|
fd int
|
|
|
|
rTimer *htimer
|
|
wTimer *htimer
|
|
|
|
writeBuffer []byte
|
|
|
|
closed bool
|
|
isWAdded bool
|
|
closeErr error
|
|
|
|
lAddr net.Addr
|
|
rAddr net.Addr
|
|
|
|
ReadBuffer []byte
|
|
|
|
session interface{}
|
|
|
|
chWaitWrite chan struct{}
|
|
|
|
execList []func()
|
|
|
|
DataHandler func(c *Conn, data []byte)
|
|
}
|
|
|
|
// Hash returns a hash code.
|
|
func (c *Conn) Hash() int {
|
|
return c.fd
|
|
}
|
|
|
|
// Read implements Read.
|
|
func (c *Conn) Read(b []byte) (int, error) {
|
|
// use lock to prevent multiple conn data confusion when fd is reused on unix.
|
|
c.mux.Lock()
|
|
if c.closed {
|
|
c.mux.Unlock()
|
|
return 0, errClosed
|
|
}
|
|
|
|
n, err := syscall.Read(c.fd, b)
|
|
c.mux.Unlock()
|
|
if err == nil {
|
|
c.g.afterRead(c)
|
|
}
|
|
|
|
return n, err
|
|
}
|
|
|
|
// Write implements Write.
|
|
func (c *Conn) Write(b []byte) (int, error) {
|
|
defer c.g.onWriteBufferFree(c, b)
|
|
|
|
c.mux.Lock()
|
|
if c.closed {
|
|
c.mux.Unlock()
|
|
return -1, errClosed
|
|
}
|
|
|
|
c.g.beforeWrite(c)
|
|
|
|
n, err := c.write(b)
|
|
if err != nil && !errors.Is(err, syscall.EINTR) && !errors.Is(err, syscall.EAGAIN) {
|
|
c.closed = true
|
|
c.mux.Unlock()
|
|
c.closeWithErrorWithoutLock(err)
|
|
return n, err
|
|
}
|
|
|
|
if len(c.writeBuffer) == 0 {
|
|
if c.wTimer != nil {
|
|
c.wTimer.Stop()
|
|
c.wTimer = nil
|
|
}
|
|
} else {
|
|
c.modWrite()
|
|
}
|
|
|
|
c.mux.Unlock()
|
|
return n, err
|
|
}
|
|
|
|
// Writev implements Writev.
|
|
func (c *Conn) Writev(in [][]byte) (int, error) {
|
|
defer func() {
|
|
for _, v := range in {
|
|
c.g.onWriteBufferFree(c, v)
|
|
}
|
|
}()
|
|
c.mux.Lock()
|
|
if c.closed {
|
|
c.mux.Unlock()
|
|
|
|
return 0, errClosed
|
|
}
|
|
|
|
c.g.beforeWrite(c)
|
|
|
|
var n int
|
|
var err error
|
|
switch len(in) {
|
|
case 1:
|
|
n, err = c.write(in[0])
|
|
default:
|
|
n, err = c.writev(in)
|
|
}
|
|
if err != nil && !errors.Is(err, syscall.EINTR) && !errors.Is(err, syscall.EAGAIN) {
|
|
c.closed = true
|
|
c.mux.Unlock()
|
|
c.closeWithErrorWithoutLock(err)
|
|
return n, err
|
|
}
|
|
if len(c.writeBuffer) == 0 {
|
|
if c.wTimer != nil {
|
|
c.wTimer.Stop()
|
|
c.wTimer = nil
|
|
}
|
|
} else {
|
|
c.modWrite()
|
|
}
|
|
|
|
c.mux.Unlock()
|
|
return n, err
|
|
}
|
|
|
|
// Close implements Close.
|
|
func (c *Conn) Close() error {
|
|
return c.closeWithError(nil)
|
|
}
|
|
|
|
// CloseWithError .
|
|
func (c *Conn) CloseWithError(err error) error {
|
|
return c.closeWithError(err)
|
|
}
|
|
|
|
// LocalAddr implements LocalAddr.
|
|
func (c *Conn) LocalAddr() net.Addr {
|
|
return c.lAddr
|
|
}
|
|
|
|
// RemoteAddr implements RemoteAddr.
|
|
func (c *Conn) RemoteAddr() net.Addr {
|
|
return c.rAddr
|
|
}
|
|
|
|
// SetDeadline implements SetDeadline.
|
|
func (c *Conn) SetDeadline(t time.Time) error {
|
|
c.mux.Lock()
|
|
if !c.closed {
|
|
if !t.IsZero() {
|
|
now := time.Now()
|
|
if c.rTimer == nil {
|
|
c.rTimer = c.g.afterFunc(t.Sub(now), func() { c.closeWithError(errReadTimeout) })
|
|
} else {
|
|
c.rTimer.Reset(t.Sub(now))
|
|
}
|
|
if c.wTimer == nil {
|
|
c.wTimer = c.g.afterFunc(t.Sub(now), func() { c.closeWithError(errWriteTimeout) })
|
|
} else {
|
|
c.wTimer.Reset(t.Sub(now))
|
|
}
|
|
} else {
|
|
if c.rTimer != nil {
|
|
c.rTimer.Stop()
|
|
c.rTimer = nil
|
|
}
|
|
if c.wTimer != nil {
|
|
c.wTimer.Stop()
|
|
c.wTimer = nil
|
|
}
|
|
}
|
|
}
|
|
c.mux.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) setDeadline(timer **htimer, returnErr error, t time.Time) error {
|
|
c.mux.Lock()
|
|
defer c.mux.Unlock()
|
|
if c.closed {
|
|
return nil
|
|
}
|
|
if !t.IsZero() {
|
|
now := time.Now()
|
|
if *timer == nil {
|
|
*timer = c.g.afterFunc(t.Sub(now), func() { c.closeWithError(returnErr) })
|
|
} else {
|
|
(*timer).Reset(t.Sub(now))
|
|
}
|
|
} else if *timer != nil {
|
|
(*timer).Stop()
|
|
(*timer) = nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetReadDeadline implements SetReadDeadline.
|
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
|
return c.setDeadline(&c.rTimer, errReadTimeout, t)
|
|
}
|
|
|
|
// SetWriteDeadline implements SetWriteDeadline.
|
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
|
return c.setDeadline(&c.wTimer, errWriteTimeout, t)
|
|
}
|
|
|
|
// SetNoDelay implements SetNoDelay.
|
|
func (c *Conn) SetNoDelay(nodelay bool) error {
|
|
if nodelay {
|
|
return syscall.SetsockoptInt(c.fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1)
|
|
}
|
|
return syscall.SetsockoptInt(c.fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 0)
|
|
}
|
|
|
|
// SetReadBuffer implements SetReadBuffer.
|
|
func (c *Conn) SetReadBuffer(bytes int) error {
|
|
return syscall.SetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes)
|
|
}
|
|
|
|
// SetWriteBuffer implements SetWriteBuffer.
|
|
func (c *Conn) SetWriteBuffer(bytes int) error {
|
|
return syscall.SetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)
|
|
}
|
|
|
|
// SetKeepAlive implements SetKeepAlive.
|
|
func (c *Conn) SetKeepAlive(keepalive bool) error {
|
|
if keepalive {
|
|
return syscall.SetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, 1)
|
|
}
|
|
return syscall.SetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, 0)
|
|
}
|
|
|
|
// SetKeepAlivePeriod implements SetKeepAlivePeriod.
|
|
func (c *Conn) SetKeepAlivePeriod(d time.Duration) error {
|
|
return errors.New("not supported")
|
|
// d += (time.Second - time.Nanosecond)
|
|
// secs := int(d.Seconds())
|
|
// if err := syscall.SetsockoptInt(c.fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, secs); err != nil {
|
|
// return err
|
|
// }
|
|
// return syscall.SetsockoptInt(c.fd, syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, secs)
|
|
}
|
|
|
|
// SetLinger implements SetLinger.
|
|
func (c *Conn) SetLinger(onoff int32, linger int32) error {
|
|
return syscall.SetsockoptLinger(c.fd, syscall.SOL_SOCKET, syscall.SO_LINGER, &syscall.Linger{
|
|
Onoff: onoff, // 1
|
|
Linger: linger, // 0
|
|
})
|
|
}
|
|
|
|
// Session returns user session.
|
|
func (c *Conn) Session() interface{} {
|
|
return c.session
|
|
}
|
|
|
|
// SetSession sets user session.
|
|
func (c *Conn) SetSession(session interface{}) {
|
|
c.session = session
|
|
}
|
|
|
|
func (c *Conn) modWrite() {
|
|
if !c.closed && !c.isWAdded {
|
|
c.isWAdded = true
|
|
c.g.pollers[c.Hash()%len(c.g.pollers)].modWrite(c.fd)
|
|
}
|
|
}
|
|
|
|
func (c *Conn) resetRead() {
|
|
if !c.closed && c.isWAdded {
|
|
c.isWAdded = false
|
|
p := c.g.pollers[c.Hash()%len(c.g.pollers)]
|
|
p.deleteEvent(c.fd)
|
|
p.addRead(c.fd)
|
|
}
|
|
}
|
|
|
|
func (c *Conn) write(b []byte) (int, error) {
|
|
if len(b) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
if c.overflow(len(b)) {
|
|
return -1, syscall.EINVAL
|
|
}
|
|
|
|
if len(c.writeBuffer) == 0 {
|
|
n, err := syscall.Write(c.fd, b)
|
|
if err != nil && !errors.Is(err, syscall.EINTR) && !errors.Is(err, syscall.EAGAIN) {
|
|
return n, err
|
|
}
|
|
if n < 0 {
|
|
n = 0
|
|
}
|
|
left := len(b) - n
|
|
if left > 0 {
|
|
c.writeBuffer = mempool.Malloc(left)
|
|
copy(c.writeBuffer, b[n:])
|
|
c.modWrite()
|
|
}
|
|
return len(b), nil
|
|
}
|
|
c.writeBuffer = append(c.writeBuffer, b...)
|
|
|
|
return len(b), nil
|
|
}
|
|
|
|
func (c *Conn) flush() error {
|
|
c.mux.Lock()
|
|
if c.closed {
|
|
c.mux.Unlock()
|
|
return errClosed
|
|
}
|
|
|
|
if len(c.writeBuffer) == 0 {
|
|
c.mux.Unlock()
|
|
return nil
|
|
}
|
|
|
|
old := c.writeBuffer
|
|
|
|
n, err := syscall.Write(c.fd, old)
|
|
if err != nil && !errors.Is(err, syscall.EINTR) && !errors.Is(err, syscall.EAGAIN) {
|
|
c.closed = true
|
|
c.mux.Unlock()
|
|
c.closeWithErrorWithoutLock(err)
|
|
return err
|
|
}
|
|
if n < 0 {
|
|
n = 0
|
|
}
|
|
left := len(old) - n
|
|
if left > 0 {
|
|
if n > 0 {
|
|
c.writeBuffer = mempool.Malloc(left)
|
|
copy(c.writeBuffer, old[n:])
|
|
mempool.Free(old)
|
|
}
|
|
// c.modWrite()
|
|
} else {
|
|
c.writeBuffer = nil
|
|
if c.wTimer != nil {
|
|
c.wTimer.Stop()
|
|
c.wTimer = nil
|
|
}
|
|
c.resetRead()
|
|
if c.chWaitWrite != nil {
|
|
select {
|
|
case c.chWaitWrite <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
c.mux.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) writev(in [][]byte) (int, error) {
|
|
size := 0
|
|
for _, v := range in {
|
|
size += len(v)
|
|
}
|
|
if c.overflow(size) {
|
|
return -1, syscall.EINVAL
|
|
}
|
|
if len(c.writeBuffer) > 0 {
|
|
for _, v := range in {
|
|
c.writeBuffer = append(c.writeBuffer, v...)
|
|
}
|
|
return size, nil
|
|
}
|
|
|
|
if len(in) > 1 && size <= 65536 {
|
|
b := mempool.Malloc(size)
|
|
copied := 0
|
|
for _, v := range in {
|
|
copy(b[copied:], v)
|
|
copied += len(v)
|
|
}
|
|
return c.write(b)
|
|
}
|
|
|
|
nwrite := 0
|
|
for _, b := range in {
|
|
n, err := c.write(b)
|
|
if n > 0 {
|
|
nwrite += n
|
|
}
|
|
if err != nil {
|
|
return nwrite, err
|
|
}
|
|
}
|
|
return nwrite, nil
|
|
}
|
|
|
|
func (c *Conn) overflow(n int) bool {
|
|
return c.g.maxWriteBufferSize > 0 && (len(c.writeBuffer)+n > c.g.maxWriteBufferSize)
|
|
}
|
|
|
|
func (c *Conn) closeWithError(err error) error {
|
|
c.mux.Lock()
|
|
if !c.closed {
|
|
c.closed = true
|
|
c.mux.Unlock()
|
|
return c.closeWithErrorWithoutLock(err)
|
|
}
|
|
c.mux.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) closeWithErrorWithoutLock(err error) error {
|
|
c.closeErr = err
|
|
|
|
if c.wTimer != nil {
|
|
c.wTimer.Stop()
|
|
c.wTimer = nil
|
|
}
|
|
if c.rTimer != nil {
|
|
c.rTimer.Stop()
|
|
c.rTimer = nil
|
|
}
|
|
|
|
mempool.Free(c.writeBuffer)
|
|
c.writeBuffer = nil
|
|
|
|
if c.chWaitWrite != nil {
|
|
select {
|
|
case c.chWaitWrite <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
if c.g != nil {
|
|
c.g.pollers[c.Hash()%len(c.g.pollers)].deleteConn(c)
|
|
}
|
|
|
|
return syscall.Close(c.fd)
|
|
}
|
|
|
|
// NBConn converts net.Conn to *Conn.
|
|
func NBConn(conn net.Conn) (*Conn, error) {
|
|
if conn == nil {
|
|
return nil, errors.New("invalid conn: nil")
|
|
}
|
|
c, ok := conn.(*Conn)
|
|
if !ok {
|
|
var err error
|
|
c, err = dupStdConn(conn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return c, nil
|
|
}
|