2021-12-04 16:42:11 +00:00

271 lines
4.9 KiB
Go

// Copyright 2020 lesismal. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
//go:build windows
// +build windows
package nbio
import (
"bytes"
"errors"
"io"
"net"
"sync"
"time"
)
// Conn wraps net.Conn and adds the bookkeeping nbio needs on this build:
// the owning Gopher, a hash used for poller selection, close state, a
// staging buffer for the onRead callback mode, and a user session slot.
type Conn struct {
// g is the owning Gopher; Close tolerates it being nil.
g *Gopher
// hash is derived from the connection address (see newConn) and selects
// the poller, g.pollers[hash%len(g.pollers)], that Close deregisters from.
hash int
// mux serializes Close so that only the first call closes conn.
mux sync.Mutex
// conn is the wrapped connection that all I/O and option calls delegate to.
conn net.Conn
// closed is set by the first Close call.
closed bool
// closeErr holds the first error recorded by Read, read, Write, Writev or
// CloseWithError, and is reported by Read thereafter.
closeErr error
// NOTE(review): ReadBuffer is not referenced in this file; confirm its use.
ReadBuffer []byte
// user session
session interface{}
// NOTE(review): execList is not referenced in this file; presumably a queue
// of pending callbacks — confirm against its users.
execList []func()
// cache stages bytes received by read() when g.onRead is set, so the
// onRead handler can consume them through Read.
cache *bytes.Buffer
// NOTE(review): DataHandler is not referenced in this file; confirm its use.
DataHandler func(c *Conn, data []byte)
}
// Hash reports the connection's hash code (as computed by newConn). Close
// uses it to choose which poller the connection is deregistered from.
func (c *Conn) Hash() int {
	h := c.hash
	return h
}
// Read wraps net.Conn.Read.
//
// When read() has staged bytes in c.cache (onRead mode), Read serves them
// from the cache and never touches the network. Buffered bytes are returned
// even if a terminal error has already been recorded, so data delivered
// together with that error is not lost. Once the cache is drained, any
// recorded close error is returned.
//
// If the cache is in use but currently empty, Read reports io.EOF (or 0, nil
// for an empty b), as bytes.Buffer would. That condition only means "nothing
// buffered right now", so it is NOT recorded as the connection's close error.
// Recording it would make every later Read fail, even after read() buffers
// fresh data.
//
// Without a cache, Read reads the underlying connection directly and records
// the first error it sees in c.closeErr.
func (c *Conn) Read(b []byte) (int, error) {
	if c.cache != nil && c.cache.Len() > 0 {
		return c.cache.Read(b)
	}
	if c.closeErr != nil {
		return 0, c.closeErr
	}
	if c.cache != nil {
		if len(b) == 0 {
			return 0, nil
		}
		return 0, io.EOF
	}
	nread, err := c.conn.Read(b)
	if err != nil && c.closeErr == nil {
		c.closeErr = err
	}
	return nread, err
}
// read performs one read from the underlying connection into b and
// dispatches the result to the gopher's callbacks.
//
// When g.onRead is set, received bytes are appended to c.cache and onRead is
// invoked, so the handler can consume them through Read. Otherwise, received
// bytes are passed straight to g.onData. The first error observed is
// recorded in c.closeErr.
//
// The read error is returned in both modes. Returning nil in onRead mode
// would hide a dead or closed connection from the caller. The caller is
// presumably the poller's read loop, which stops on a non-nil error —
// confirm. Hiding the error would make that loop spin forever.
func (c *Conn) read(b []byte) (int, error) {
	c.g.beforeRead(c)
	nread, err := c.conn.Read(b)
	if err != nil && c.closeErr == nil {
		c.closeErr = err
	}
	if c.g.onRead != nil {
		if nread > 0 {
			if c.cache == nil {
				c.cache = bytes.NewBuffer(nil)
			}
			c.cache.Write(b[:nread])
		}
		// onRead still runs when err != nil so the handler can drain the
		// cached bytes and then observe the error through Read.
		c.g.onRead(c)
		return nread, err
	}
	if nread > 0 {
		c.g.onData(c, b[:nread])
	}
	return nread, err
}
// Write sends b on the underlying connection, wrapping net.Conn.Write.
//
// The gopher's beforeWrite hook runs first. If the write fails, the error
// becomes the connection's close error (unless one is already recorded) and
// the connection is closed. Whatever the outcome, b is then handed to the
// gopher's onWriteBufferFree callback.
func (c *Conn) Write(b []byte) (int, error) {
	c.g.beforeWrite(c)

	written, werr := c.conn.Write(b)
	failed := werr != nil
	if failed && c.closeErr == nil {
		c.closeErr = werr
	}
	if failed {
		c.Close()
	}

	c.g.onWriteBufferFree(c, b)
	return written, werr
}
// Writev writes every buffer in in to the underlying connection using
// net.Buffers.WriteTo, which the runtime may turn into a single batched
// (vectored) write.
//
// The gopher's beforeWrite hook runs first, matching Write. If the write
// fails, the error becomes the connection's close error (unless one is
// already recorded) and the connection is closed. Every buffer in in is
// then handed to onWriteBufferFree. The result is the total number of bytes
// written.
func (c *Conn) Writev(in [][]byte) (int, error) {
	c.g.beforeWrite(c)

	// WriteTo consumes its receiver: it modifies the slice AND its elements,
	// setting fully written entries to nil. A plain conversion,
	// net.Buffers(in), shares the backing array with in. That would leave
	// onWriteBufferFree below receiving nil slices, so write through a copy
	// instead.
	buffers := make(net.Buffers, len(in))
	copy(buffers, in)
	nwrite, err := buffers.WriteTo(c.conn)
	if err != nil {
		if c.closeErr == nil {
			c.closeErr = err
		}
		c.Close()
	}
	for _, v := range in {
		c.g.onWriteBufferFree(c, v)
	}
	return int(nwrite), err
}
// Close closes the underlying connection exactly once.
//
// The first call marks the Conn closed, closes the wrapped net.Conn and, when
// the Conn has an owning Gopher, removes it from the poller selected by its
// hash. The lock is released before deleteConn is called. Later calls do
// nothing and return nil.
func (c *Conn) Close() error {
	c.mux.Lock()
	if c.closed {
		c.mux.Unlock()
		return nil
	}
	c.closed = true
	err := c.conn.Close()
	c.mux.Unlock()

	if c.g == nil {
		return err
	}
	c.g.pollers[c.Hash()%len(c.g.pollers)].deleteConn(c)
	return err
}
// CloseWithError records err as the connection's close error, unless an
// earlier error is already recorded, and then closes the connection via
// Close, returning Close's result.
func (c *Conn) CloseWithError(err error) error {
	if c.closeErr != nil {
		return c.Close()
	}
	c.closeErr = err
	return c.Close()
}
// LocalAddr returns the local network address reported by the wrapped
// net.Conn.
func (c *Conn) LocalAddr() net.Addr {
	local := c.conn.LocalAddr()
	return local
}
// RemoteAddr returns the peer's network address reported by the wrapped
// net.Conn.
func (c *Conn) RemoteAddr() net.Addr {
	remote := c.conn.RemoteAddr()
	return remote
}
// SetDeadline sets both read and write deadlines on the wrapped net.Conn.
//
// A zero t does not clear the deadline. Instead it is replaced with
// time.Now().Add(timeForever), where timeForever is a package-level duration
// defined elsewhere.
func (c *Conn) SetDeadline(t time.Time) error {
	deadline := t
	if deadline.IsZero() {
		deadline = time.Now().Add(timeForever)
	}
	return c.conn.SetDeadline(deadline)
}
// SetReadDeadline sets the read deadline on the wrapped net.Conn.
//
// A zero t is mapped to time.Now().Add(timeForever) rather than clearing the
// deadline.
func (c *Conn) SetReadDeadline(t time.Time) error {
	deadline := t
	if deadline.IsZero() {
		deadline = time.Now().Add(timeForever)
	}
	return c.conn.SetReadDeadline(deadline)
}
// SetWriteDeadline sets the write deadline on the wrapped net.Conn.
//
// A zero t is mapped to time.Now().Add(timeForever) rather than clearing the
// deadline.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	deadline := t
	if deadline.IsZero() {
		deadline = time.Now().Add(timeForever)
	}
	return c.conn.SetWriteDeadline(deadline)
}
// SetNoDelay forwards to net.TCPConn.SetNoDelay, which controls Nagle's
// algorithm. It is a no-op returning nil when the wrapped connection is not
// a *net.TCPConn.
func (c *Conn) SetNoDelay(nodelay bool) error {
	if tcp, isTCP := c.conn.(*net.TCPConn); isTCP {
		return tcp.SetNoDelay(nodelay)
	}
	return nil
}
// SetReadBuffer forwards to net.TCPConn.SetReadBuffer to size the OS receive
// buffer. It is a no-op returning nil for non-TCP connections.
func (c *Conn) SetReadBuffer(bytes int) error {
	if tcp, isTCP := c.conn.(*net.TCPConn); isTCP {
		return tcp.SetReadBuffer(bytes)
	}
	return nil
}
// SetWriteBuffer forwards to net.TCPConn.SetWriteBuffer to size the OS send
// buffer. It is a no-op returning nil for non-TCP connections.
func (c *Conn) SetWriteBuffer(bytes int) error {
	if tcp, isTCP := c.conn.(*net.TCPConn); isTCP {
		return tcp.SetWriteBuffer(bytes)
	}
	return nil
}
// SetKeepAlive forwards to net.TCPConn.SetKeepAlive to enable or disable TCP
// keep-alive probes. It is a no-op returning nil for non-TCP connections.
func (c *Conn) SetKeepAlive(keepalive bool) error {
	if tcp, isTCP := c.conn.(*net.TCPConn); isTCP {
		return tcp.SetKeepAlive(keepalive)
	}
	return nil
}
// SetKeepAlivePeriod forwards to net.TCPConn.SetKeepAlivePeriod. It is a
// no-op returning nil for non-TCP connections.
func (c *Conn) SetKeepAlivePeriod(d time.Duration) error {
	if tcp, isTCP := c.conn.(*net.TCPConn); isTCP {
		return tcp.SetKeepAlivePeriod(d)
	}
	return nil
}
// SetLinger configures close-time lingering using SO_LINGER-style arguments
// and applies them through net.TCPConn.SetLinger.
//
// onoff == 0 disables lingering, as with SO_LINGER's l_onoff == 0: linger is
// ignored and the default close behaviour is restored with SetLinger(-1).
// For any other onoff, linger is the timeout in seconds, passed through
// unchanged; 0 discards unsent data on Close.
//
// It is a no-op returning nil when the wrapped connection is not a
// *net.TCPConn.
func (c *Conn) SetLinger(onoff int32, linger int32) error {
	conn, ok := c.conn.(*net.TCPConn)
	if !ok {
		return nil
	}
	if onoff == 0 {
		// Passing linger through here would turn SetLinger(0, 0), "lingering
		// off", into an abortive close that drops unsent data.
		return conn.SetLinger(-1)
	}
	return conn.SetLinger(int(linger))
}
// Session returns the user-defined value last stored with SetSession, or nil
// if none has been stored.
func (c *Conn) Session() interface{} {
	s := c.session
	return s
}
// SetSession stores an arbitrary user-defined value on the connection,
// replacing any previous one; Session retrieves it.
func (c *Conn) SetSession(session interface{}) {
	c.session = session
}
// newConn wraps conn in a *Conn and derives its hash code.
//
// The hash is a multiply-by-31 string hash over the runes of the remote
// address, or of the local address when any fromClient argument is given
// (NBConn passes true). The result is forced to be non-negative because
// Close uses it as the left operand of a modulus that indexes g.pollers.
func newConn(conn net.Conn, fromClient ...interface{}) *Conn {
	addr := conn.RemoteAddr().String()
	if len(fromClient) > 0 {
		addr = conn.LocalAddr().String()
	}

	h := 0
	for _, ch := range addr {
		h = 31*h + int(ch) // wraps on overflow; only the spread matters
	}
	if h < 0 {
		h = -h
		// Negating math.MinInt overflows back to math.MinInt, the only value
		// still negative here. Its non-sign bits are all zero, so map it to 0
		// rather than let Close index pollers with a negative value.
		if h < 0 {
			h = 0
		}
	}

	return &Conn{
		conn: conn,
		hash: h,
	}
}
// NBConn converts a net.Conn into a *Conn.
//
// A value that already is a *Conn is returned as is. Any other connection is
// wrapped with newConn, hashing its local address. A nil conn yields an
// error.
func NBConn(conn net.Conn) (*Conn, error) {
	if conn == nil {
		return nil, errors.New("invalid conn: nil")
	}
	if existing, isConn := conn.(*Conn); isConn {
		return existing, nil
	}
	return newConn(conn, true), nil
}