454 lines
11 KiB
Go
Raw Normal View History

2021-11-21 15:25:11 +00:00
package socks5
import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"strings"
"time"
cache "github.com/patrickmn/go-cache"
"github.com/txthinking/runnergroup"
)
// Sentinel errors returned by the server.
var (
// ErrUnsupportCmd is returned when a request carries a command that is not
// listed in Server.SupportedCommands or that the handler does not implement.
ErrUnsupportCmd = errors.New("Unsupport Command")
// ErrUserPassAuth is returned by Negotiate when the client supplies a
// username or password that does not match the server's credentials.
ErrUserPassAuth = errors.New("Invalid Username or Password for Auth")
)
// Server is a SOCKS5 server. It holds the listening addresses, the
// authentication settings and the state shared by the TCP and UDP loops.
type Server struct {
// UserName and Password are the credentials Negotiate compares against when
// Method is MethodUsernamePassword.
UserName string
Password string
// Method is the single negotiation method the server accepts.
Method byte
// SupportedCommands lists the request commands GetRequest lets through.
SupportedCommands []byte
// TCPAddr is the address RunTCPServer listens on.
TCPAddr *net.TCPAddr
// UDPAddr is the address RunUDPServer listens on.
UDPAddr *net.UDPAddr
// ServerAddr is passed to Request.UDP when handling a UDP ASSOCIATE request.
// NOTE(review): presumably the address advertised to clients for UDP relay — confirm in Request.UDP.
ServerAddr *net.UDPAddr
// TCPListen is the listener created by RunTCPServer; nil until then.
TCPListen *net.TCPListener
// UDPConn is the socket created by RunUDPServer; nil until then.
UDPConn *net.UDPConn
// UDPExchanges caches *UDPExchange values keyed by client address + destination address.
UDPExchanges *cache.Cache
// TCPTimeout and UDPTimeout are deadlines in seconds; 0 means no deadline is set.
TCPTimeout int
UDPTimeout int
// Handle is the Handler in use; it is assigned by ListenAndServe.
Handle Handler
// AssociatedUDP maps a client UDP address string to a chan byte that is
// closed when the TCP connection of the UDP association ends.
AssociatedUDP *cache.Cache
// RunnerGroup runs the TCP and UDP servers registered by ListenAndServe.
RunnerGroup *runnergroup.RunnerGroup
// RFC: [UDP ASSOCIATE] The server MAY use this information to limit access to the association. Default false, no limit.
// When true, UDPHandle rejects datagrams whose source address is not in AssociatedUDP.
LimitUDP bool
}
// UDPExchange pairs a client address with the UDP connection opened to one
// destination on that client's behalf.
type UDPExchange struct {
// ClientAddr is where datagrams read from RemoteConn are sent back to.
ClientAddr *net.UDPAddr
// RemoteConn is the connection dialed to the datagram's destination.
RemoteConn *net.UDPConn
}
// NewClassicServer builds a Server that listens on addr for both TCP and UDP.
// ip is joined with the port of addr to form ServerAddr. When both username
// and password are non-empty the server uses MethodUsernamePassword;
// otherwise it uses MethodNone. tcpTimeout and udpTimeout are stored
// unchanged (the server loops treat them as seconds, 0 meaning no deadline).
// An error is returned if addr cannot be split or any address fails to resolve.
func NewClassicServer(addr, ip, username, password string, tcpTimeout, udpTimeout int) (*Server, error) {
	_, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		return nil, err
	}
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, err
	}
	serverAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(ip, port))
	if err != nil {
		return nil, err
	}

	// Authentication is only required when a full credential pair is given.
	method := MethodNone
	if username != "" && password != "" {
		method = MethodUsernamePassword
	}

	return &Server{
		Method:            method,
		UserName:          username,
		Password:          password,
		SupportedCommands: []byte{CmdConnect, CmdUDP},
		TCPAddr:           tcpAddr,
		UDPAddr:           udpAddr,
		ServerAddr:        serverAddr,
		UDPExchanges:      cache.New(cache.NoExpiration, cache.NoExpiration),
		TCPTimeout:        tcpTimeout,
		UDPTimeout:        udpTimeout,
		AssociatedUDP:     cache.New(cache.NoExpiration, cache.NoExpiration),
		RunnerGroup:       runnergroup.New(),
	}, nil
}
// Negotiate performs the SOCKS5 method negotiation on rw and, when the
// server's method is MethodUsernamePassword, the username/password
// sub-negotiation that follows it.
// This method does not handle the gssapi (0x01) method.
//
// A reply is written to rw in every case:
//   - if the client does not offer s.Method, a MethodUnsupportAll reply is
//     sent and an error is returned so the caller drops the connection;
//   - if the credentials do not match, a failure status is sent and
//     ErrUserPassAuth is returned;
//   - otherwise the success replies are sent and nil is returned.
//
// Read and write errors on rw are returned unchanged.
func (s *Server) Negotiate(rw io.ReadWriter) error {
	rq, err := NewNegotiationRequestFrom(rw)
	if err != nil {
		return err
	}

	var got bool
	for _, m := range rq.Methods {
		if m == s.Method {
			got = true
			break
		}
	}
	if !got {
		rp := NewNegotiationReply(MethodUnsupportAll)
		if _, err := rp.WriteTo(rw); err != nil {
			return err
		}
		// The client has been told no method is acceptable; stop here rather
		// than sending a second reply and treating the negotiation as done.
		return errors.New("no acceptable negotiation method")
	}

	rp := NewNegotiationReply(s.Method)
	if _, err := rp.WriteTo(rw); err != nil {
		return err
	}

	if s.Method == MethodUsernamePassword {
		urq, err := NewUserPassNegotiationRequestFrom(rw)
		if err != nil {
			return err
		}
		if string(urq.Uname) != s.UserName || string(urq.Passwd) != s.Password {
			urp := NewUserPassNegotiationReply(UserPassStatusFailure)
			if _, err := urp.WriteTo(rw); err != nil {
				return err
			}
			return ErrUserPassAuth
		}
		urp := NewUserPassNegotiationReply(UserPassStatusSuccess)
		if _, err := urp.WriteTo(rw); err != nil {
			return err
		}
	}
	return nil
}
// GetRequest reads the client's request from rw and returns it when its
// command is one of s.SupportedCommands. For any other command a
// RepCommandNotSupported reply is written to rw (with an all-zero IPv4
// address for IPv4/domain requests, an all-zero IPv6 address otherwise) and
// ErrUnsupportCmd is returned. Read and write errors are returned unchanged.
func (s *Server) GetRequest(rw io.ReadWriter) (*Request, error) {
	req, err := NewRequestFrom(rw)
	if err != nil {
		return nil, err
	}

	for i := range s.SupportedCommands {
		if s.SupportedCommands[i] == req.Cmd {
			return req, nil
		}
	}

	// The command is not supported: tell the client before failing.
	var reply *Reply
	switch req.Atyp {
	case ATYPIPv4, ATYPDomain:
		reply = NewReply(RepCommandNotSupported, ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00})
	default:
		reply = NewReply(RepCommandNotSupported, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00})
	}
	if _, werr := reply.WriteTo(rw); werr != nil {
		return nil, werr
	}
	return nil, ErrUnsupportCmd
}
// ListenAndServe installs h as the server's Handler (a DefaultHandle when h
// is nil), registers the TCP and UDP servers with the RunnerGroup and returns
// the result of waiting on the group.
func (s *Server) ListenAndServe(h Handler) error {
	handler := h
	if handler == nil {
		handler = &DefaultHandle{}
	}
	s.Handle = handler

	tcpRunner := &runnergroup.Runner{
		Start: func() error {
			return s.RunTCPServer()
		},
		Stop: func() error {
			// The listener only exists once RunTCPServer has started.
			if s.TCPListen == nil {
				return nil
			}
			return s.TCPListen.Close()
		},
	}
	s.RunnerGroup.Add(tcpRunner)

	udpRunner := &runnergroup.Runner{
		Start: func() error {
			return s.RunUDPServer()
		},
		Stop: func() error {
			// The socket only exists once RunUDPServer has started.
			if s.UDPConn == nil {
				return nil
			}
			return s.UDPConn.Close()
		},
	}
	s.RunnerGroup.Add(udpRunner)

	return s.RunnerGroup.Wait()
}
// RunTCPServer listens on s.TCPAddr and serves every accepted connection in
// its own goroutine: negotiate, read the request, then hand over to
// s.Handle.TCPHandle. Per-connection failures are logged. The function only
// returns when listening or accepting fails, and returns that error.
func (s *Server) RunTCPServer() error {
	var err error
	if s.TCPListen, err = net.ListenTCP("tcp", s.TCPAddr); err != nil {
		return err
	}
	defer s.TCPListen.Close()

	// serve runs the whole lifecycle of one client connection.
	serve := func(conn *net.TCPConn) {
		defer conn.Close()
		if s.TCPTimeout != 0 {
			deadline := time.Now().Add(time.Duration(s.TCPTimeout) * time.Second)
			if derr := conn.SetDeadline(deadline); derr != nil {
				log.Println(derr)
				return
			}
		}
		if nerr := s.Negotiate(conn); nerr != nil {
			log.Println(nerr)
			return
		}
		req, rerr := s.GetRequest(conn)
		if rerr != nil {
			log.Println(rerr)
			return
		}
		if herr := s.Handle.TCPHandle(s, conn, req); herr != nil {
			log.Println(herr)
		}
	}

	for {
		conn, aerr := s.TCPListen.AcceptTCP()
		if aerr != nil {
			return aerr
		}
		go serve(conn)
	}
}
// RunUDPServer listens on s.UDPAddr and dispatches every received packet in
// its own goroutine: the packet is parsed as a Datagram, fragmented
// datagrams are ignored, and the rest are passed to s.Handle.UDPHandle.
// Per-packet failures are logged. The function only returns when listening
// or reading fails, and returns that error.
func (s *Server) RunUDPServer() error {
	var err error
	if s.UDPConn, err = net.ListenUDP("udp", s.UDPAddr); err != nil {
		return err
	}
	defer s.UDPConn.Close()

	// dispatch processes a single packet received from a client.
	dispatch := func(from *net.UDPAddr, payload []byte) {
		dgram, perr := NewDatagramFromBytes(payload)
		if perr != nil {
			log.Println(perr)
			return
		}
		if d := dgram; d.Frag != 0x00 {
			log.Println("Ignore frag", d.Frag)
			return
		}
		if herr := s.Handle.UDPHandle(s, from, dgram); herr != nil {
			log.Println(herr)
			return
		}
	}

	for {
		// A fresh buffer per packet: the slice is handed to a goroutine.
		buf := make([]byte, 65507)
		n, from, rerr := s.UDPConn.ReadFromUDP(buf)
		if rerr != nil {
			return rerr
		}
		go dispatch(from, buf[:n])
	}
}
// Shutdown stops the server by calling Done on its RunnerGroup and returns
// whatever Done returns.
func (s *Server) Shutdown() error {
	group := s.RunnerGroup
	return group.Done()
}
// Handler processes the TCP requests and UDP datagrams accepted by a Server.
type Handler interface {
// TCPHandle is called by RunTCPServer once negotiation has succeeded and the
// request has been read and its command checked.
// Request has not been replied yet
TCPHandle(*Server, *net.TCPConn, *Request) error
// UDPHandle is called by RunUDPServer for every non-fragmented datagram,
// together with the address the datagram was received from.
UDPHandle(*Server, *net.UDPAddr, *Datagram) error
}
// DefaultHandle implements the Handler interface. ListenAndServe installs it
// when it is given a nil Handler.
type DefaultHandle struct {
}
// TCPHandle serves one client request on c. You may prefer to write your own.
//
// For CmdConnect it opens the remote connection via r.Connect and relays
// bytes in both directions until either side fails; relay errors end the
// session and are not reported (nil is returned). When s.TCPTimeout is
// non-zero the deadline of the side being read is pushed forward before
// every read.
//
// For CmdUDP it completes the association via r.UDP, registers a channel for
// the client's UDP address in s.AssociatedUDP, and blocks draining c until
// the TCP connection ends; the entry is then removed and the channel closed.
//
// Any other command yields ErrUnsupportCmd.
func (h *DefaultHandle) TCPHandle(s *Server, c *net.TCPConn, r *Request) error {
	switch r.Cmd {
	case CmdConnect:
		remote, err := r.Connect(c)
		if err != nil {
			return err
		}
		defer remote.Close()

		// remote -> client; stops once remote is closed by the defer above.
		go func() {
			var buf [1024 * 2]byte
			for {
				if s.TCPTimeout != 0 {
					deadline := time.Now().Add(time.Duration(s.TCPTimeout) * time.Second)
					if err := remote.SetDeadline(deadline); err != nil {
						return
					}
				}
				n, err := remote.Read(buf[:])
				if err != nil {
					return
				}
				if _, err := c.Write(buf[:n]); err != nil {
					return
				}
			}
		}()

		// client -> remote, on the calling goroutine.
		var buf [1024 * 2]byte
		for {
			if s.TCPTimeout != 0 {
				deadline := time.Now().Add(time.Duration(s.TCPTimeout) * time.Second)
				if err := c.SetDeadline(deadline); err != nil {
					return nil
				}
			}
			n, err := c.Read(buf[:])
			if err != nil {
				return nil
			}
			if _, err := remote.Write(buf[:n]); err != nil {
				return nil
			}
		}

	case CmdUDP:
		caddr, err := r.UDP(c, s.ServerAddr)
		if err != nil {
			return err
		}
		key := caddr.String()

		// Closing the channel tells the UDP side that this association ended.
		done := make(chan byte)
		defer close(done)
		s.AssociatedUDP.Set(key, done, -1)
		defer s.AssociatedUDP.Delete(key)

		// Block until the client closes the TCP connection.
		io.Copy(ioutil.Discard, c)
		if Debug {
			log.Printf("A tcp connection that udp %#v associated closed\n", key)
		}
		return nil
	}
	return ErrUnsupportCmd
}
// UDPHandle relays one client datagram d, received from addr, to the
// destination named in the datagram. You may prefer to write your own.
//
// A remote UDP connection is kept per (client address, destination) pair in
// s.UDPExchanges. If one exists the data is written to it; otherwise a new
// connection is dialed, the data is sent, the exchange is cached and a
// goroutine is started that wraps every response from the remote in a
// Datagram and writes it back to the client through s.UDPConn.
//
// When s.LimitUDP is set, datagrams from addresses that have no entry in
// s.AssociatedUDP are rejected with an error.
func (h *DefaultHandle) UDPHandle(s *Server, addr *net.UDPAddr, d *Datagram) error {
// src identifies the client; it is half of the exchange cache key.
src := addr.String()
// ch stays nil unless LimitUDP is set. A receive from a nil channel never
// proceeds, so every select on ch below then takes its default branch.
var ch chan byte
if s.LimitUDP {
// NOTE(review): the local name `any` shadows the predeclared identifier in Go 1.18+.
any, ok := s.AssociatedUDP.Get(src)
if !ok {
return fmt.Errorf("This udp address %s is not associated with tcp", src)
}
// The stored value is the channel TCPHandle closes when the TCP connection ends.
ch = any.(chan byte)
}
// send writes data to the exchange's remote connection, unless ch has been
// closed, in which case it reports that the association is gone.
send := func(ue *UDPExchange, data []byte) error {
select {
case <-ch:
return fmt.Errorf("This udp address %s is not associated with tcp", src)
default:
_, err := ue.RemoteConn.Write(data)
if err != nil {
return err
}
if Debug {
log.Printf("Sent UDP data to remote. client: %#v server: %#v remote: %#v data: %#v\n", ue.ClientAddr.String(), ue.RemoteConn.LocalAddr().String(), ue.RemoteConn.RemoteAddr().String(), data)
}
}
return nil
}
dst := d.Address()
var ue *UDPExchange
// Fast path: reuse the cached exchange for this client/destination pair.
iue, ok := s.UDPExchanges.Get(src + dst)
if ok {
ue = iue.(*UDPExchange)
err := send(ue, d.Data)
if err == nil {
return nil
}
// Only an error whose text contains "closed" falls through to re-dial;
// every other error is returned to the caller.
if !strings.Contains(err.Error(), "closed") {
return err
}
}
if Debug {
log.Printf("Call udp: %#v\n", dst)
}
// Slow path: resolve the destination and open a new remote connection.
raddr, err := net.ResolveUDPAddr("udp", dst)
if err != nil {
return err
}
rc, err := Dial.DialUDP("udp", nil, raddr)
if err != nil {
return err
}
ue = &UDPExchange{
ClientAddr: addr,
RemoteConn: rc,
}
if Debug {
log.Printf("Created remote UDP conn for client. client: %#v server: %#v remote: %#v\n", addr.String(), ue.RemoteConn.LocalAddr().String(), d.Address())
}
// The exchange is only cached after the first send succeeds.
if err := send(ue, d.Data); err != nil {
ue.RemoteConn.Close()
return err
}
s.UDPExchanges.Set(src+dst, ue, -1)
// Relay responses from the remote back to the client until the association
// channel is closed or a read, parse or write step fails.
go func(ue *UDPExchange, dst string) {
// On exit, close the remote connection and drop the cache entry.
// NOTE(review): the delete is by key, so it would also remove an entry that
// a concurrent call stored under the same key — confirm whether that can happen.
defer func() {
ue.RemoteConn.Close()
s.UDPExchanges.Delete(ue.ClientAddr.String() + dst)
}()
var b [65507]byte
for {
select {
case <-ch:
if Debug {
log.Printf("The tcp that udp address %s associated closed\n", ue.ClientAddr.String())
}
return
default:
// ch is only re-checked between reads: the Read below blocks until data
// arrives, the deadline (if any) expires, or the connection is closed.
if s.UDPTimeout != 0 {
if err := ue.RemoteConn.SetDeadline(time.Now().Add(time.Duration(s.UDPTimeout) * time.Second)); err != nil {
log.Println(err)
return
}
}
n, err := ue.RemoteConn.Read(b[:])
if err != nil {
return
}
if Debug {
log.Printf("Got UDP data from remote. client: %#v server: %#v remote: %#v data: %#v\n", ue.ClientAddr.String(), ue.RemoteConn.LocalAddr().String(), ue.RemoteConn.RemoteAddr().String(), b[0:n])
}
// The reply datagram carries the original destination as its address.
a, addr, port, err := ParseAddress(dst)
if err != nil {
log.Println(err)
return
}
d1 := NewDatagram(a, addr, port, b[0:n])
if _, err := s.UDPConn.WriteToUDP(d1.Bytes(), ue.ClientAddr); err != nil {
return
}
if Debug {
log.Printf("Sent Datagram. client: %#v server: %#v remote: %#v data: %#v %#v %#v %#v %#v %#v datagram address: %#v\n", ue.ClientAddr.String(), ue.RemoteConn.LocalAddr().String(), ue.RemoteConn.RemoteAddr().String(), d1.Rsv, d1.Frag, d1.Atyp, d1.DstAddr, d1.DstPort, d1.Data, d1.Address())
}
}
}
}(ue, dst)
return nil
}