package socks5

import (
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"strings"
	"time"

	cache "github.com/patrickmn/go-cache"
	"github.com/txthinking/runnergroup"
)

var (
	// ErrUnsupportCmd is the error when got unsupport command
	ErrUnsupportCmd = errors.New("Unsupport Command")
	// ErrUserPassAuth is the error when got invalid username or password
	ErrUserPassAuth = errors.New("Invalid Username or Password for Auth")
)

// Server is socks5 server wrapper.
type Server struct {
	// UserName and Password are the credentials Negotiate checks when Method
	// is MethodUsernamePassword.
	UserName string
	Password string
	// Method is the single authentication method this server offers;
	// Negotiate looks for it among the methods the client proposes.
	Method byte
	// SupportedCommands lists the request commands GetRequest accepts.
	SupportedCommands []byte
	// TCPAddr and UDPAddr are the local addresses RunTCPServer and
	// RunUDPServer listen on.
	TCPAddr *net.TCPAddr
	UDPAddr *net.UDPAddr
	// ServerAddr is passed to Request.UDP when answering UDP ASSOCIATE
	// (presumably reported to the client as the relay address — confirm).
	ServerAddr *net.UDPAddr
	// TCPListen and UDPConn are populated by RunTCPServer and RunUDPServer.
	TCPListen *net.TCPListener
	UDPConn   *net.UDPConn
	// UDPExchanges maps "client address + destination address" to the
	// *UDPExchange that DefaultHandle.UDPHandle relays that pair through.
	UDPExchanges *cache.Cache
	// TCPTimeout and UDPTimeout are deadlines in seconds; 0 disables them.
	TCPTimeout int
	UDPTimeout int
	// Handle processes accepted requests and datagrams; set by ListenAndServe.
	Handle Handler
	// AssociatedUDP maps a client UDP address string to a chan byte that
	// DefaultHandle.TCPHandle closes when the UDP ASSOCIATE TCP connection ends.
	AssociatedUDP *cache.Cache
	// RunnerGroup starts and stops the TCP and UDP servers together.
	RunnerGroup *runnergroup.RunnerGroup
	// RFC: [UDP ASSOCIATE] The server MAY use this information to limit access to the association. Default false, no limit.
	LimitUDP bool
}

// UDPExchange used to store client address and remote connection.
type UDPExchange struct {
	ClientAddr *net.UDPAddr
	RemoteConn *net.UDPConn
}

// NewClassicServer returns a server that listens on addr for both TCP and UDP.
//
// ip is joined with addr's port to build ServerAddr. When both username and
// password are non-empty the server requires username/password
// authentication; otherwise it accepts MethodNone. tcpTimeout and udpTimeout
// are in seconds, and 0 disables the corresponding timeout.
func NewClassicServer(addr, ip, username, password string, tcpTimeout, udpTimeout int) (*Server, error) {
	_, p, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	taddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		return nil, err
	}
	uaddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, err
	}
	saddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(ip, p))
	if err != nil {
		return nil, err
	}
	m := MethodNone
	if username != "" && password != "" {
		m = MethodUsernamePassword
	}
	cs := cache.New(cache.NoExpiration, cache.NoExpiration)
	cs1 := cache.New(cache.NoExpiration, cache.NoExpiration)
	s := &Server{
		Method:            m,
		UserName:          username,
		Password:          password,
		SupportedCommands: []byte{CmdConnect, CmdUDP},
		TCPAddr:           taddr,
		UDPAddr:           uaddr,
		ServerAddr:        saddr,
		UDPExchanges:      cs,
		TCPTimeout:        tcpTimeout,
		UDPTimeout:        udpTimeout,
		AssociatedUDP:     cs1,
		RunnerGroup:       runnergroup.New(),
	}
	return s, nil
}

// Negotiate handle negotiate packet.
// This method do not handle gssapi(0x01) method now.
// Error or OK both replied.
//
// If the client does not offer s.Method, the client is sent
// MethodUnsupportAll and an error is returned. Otherwise s.Method is
// selected and, for MethodUsernamePassword, the credentials are checked
// (ErrUserPassAuth on mismatch, after replying failure).
func (s *Server) Negotiate(rw io.ReadWriter) error {
	rq, err := NewNegotiationRequestFrom(rw)
	if err != nil {
		return err
	}
	var got bool
	for _, m := range rq.Methods {
		if m == s.Method {
			got = true
			break
		}
	}
	if !got {
		rp := NewNegotiationReply(MethodUnsupportAll)
		if _, err := rp.WriteTo(rw); err != nil {
			return err
		}
		// RFC 1928: after the X'FF' reply the client must close the
		// connection, so negotiation ends here. Previously the code fell
		// through and also announced s.Method.
		return errors.New("Unsupport Method")
	}
	rp := NewNegotiationReply(s.Method)
	if _, err := rp.WriteTo(rw); err != nil {
		return err
	}
	if s.Method == MethodUsernamePassword {
		urq, err := NewUserPassNegotiationRequestFrom(rw)
		if err != nil {
			return err
		}
		if string(urq.Uname) != s.UserName || string(urq.Passwd) != s.Password {
			urp := NewUserPassNegotiationReply(UserPassStatusFailure)
			if _, err := urp.WriteTo(rw); err != nil {
				return err
			}
			return ErrUserPassAuth
		}
		urp := NewUserPassNegotiationReply(UserPassStatusSuccess)
		if _, err := urp.WriteTo(rw); err != nil {
			return err
		}
	}
	return nil
}

// GetRequest reads a request packet from the client and checks its command
// against SupportedCommands. An unsupported command is answered with
// RepCommandNotSupported before ErrUnsupportCmd is returned.
func (s *Server) GetRequest(rw io.ReadWriter) (*Request, error) { r, err := NewRequestFrom(rw) if err != nil { return nil, err } var supported bool for _, c := range s.SupportedCommands { if r.Cmd == c { supported = true break } } if !supported { var p *Reply if r.Atyp == ATYPIPv4 || r.Atyp == ATYPDomain { p = NewReply(RepCommandNotSupported, ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}) } else { p = NewReply(RepCommandNotSupported, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00}) } if _, err := p.WriteTo(rw); err != nil { return nil, err } return nil, ErrUnsupportCmd } return r, nil } // Run server func (s *Server) ListenAndServe(h Handler) error { if h == nil { s.Handle = &DefaultHandle{} } else { s.Handle = h } s.RunnerGroup.Add(&runnergroup.Runner{ Start: func() error { return s.RunTCPServer() }, Stop: func() error { if s.TCPListen != nil { return s.TCPListen.Close() } return nil }, }) s.RunnerGroup.Add(&runnergroup.Runner{ Start: func() error { return s.RunUDPServer() }, Stop: func() error { if s.UDPConn != nil { return s.UDPConn.Close() } return nil }, }) return s.RunnerGroup.Wait() } // RunTCPServer starts tcp server func (s *Server) RunTCPServer() error { var err error s.TCPListen, err = net.ListenTCP("tcp", s.TCPAddr) if err != nil { return err } defer s.TCPListen.Close() for { c, err := s.TCPListen.AcceptTCP() if err != nil { return err } go func(c *net.TCPConn) { defer c.Close() if s.TCPTimeout != 0 { if err := c.SetDeadline(time.Now().Add(time.Duration(s.TCPTimeout) * time.Second)); err != nil { log.Println(err) return } } if err := s.Negotiate(c); err != nil { log.Println(err) return } r, err := s.GetRequest(c) if err != nil { log.Println(err) return } if err := s.Handle.TCPHandle(s, c, r); err != nil { log.Println(err) } }(c) } return nil } // RunUDPServer starts udp server func (s *Server) RunUDPServer() error { var err error s.UDPConn, err = net.ListenUDP("udp", s.UDPAddr) if err != nil { return err } defer s.UDPConn.Close() for { 
b := make([]byte, 65507) n, addr, err := s.UDPConn.ReadFromUDP(b) if err != nil { return err } go func(addr *net.UDPAddr, b []byte) { d, err := NewDatagramFromBytes(b) if err != nil { log.Println(err) return } if d.Frag != 0x00 { log.Println("Ignore frag", d.Frag) return } if err := s.Handle.UDPHandle(s, addr, d); err != nil { log.Println(err) return } }(addr, b[0:n]) } return nil } // Stop server func (s *Server) Shutdown() error { return s.RunnerGroup.Done() } // Handler handle tcp, udp request type Handler interface { // Request has not been replied yet TCPHandle(*Server, *net.TCPConn, *Request) error UDPHandle(*Server, *net.UDPAddr, *Datagram) error } // DefaultHandle implements Handler interface type DefaultHandle struct { } // TCPHandle auto handle request. You may prefer to do yourself. func (h *DefaultHandle) TCPHandle(s *Server, c *net.TCPConn, r *Request) error { if r.Cmd == CmdConnect { rc, err := r.Connect(c) if err != nil { return err } defer rc.Close() go func() { var bf [1024 * 2]byte for { if s.TCPTimeout != 0 { if err := rc.SetDeadline(time.Now().Add(time.Duration(s.TCPTimeout) * time.Second)); err != nil { return } } i, err := rc.Read(bf[:]) if err != nil { return } if _, err := c.Write(bf[0:i]); err != nil { return } } }() var bf [1024 * 2]byte for { if s.TCPTimeout != 0 { if err := c.SetDeadline(time.Now().Add(time.Duration(s.TCPTimeout) * time.Second)); err != nil { return nil } } i, err := c.Read(bf[:]) if err != nil { return nil } if _, err := rc.Write(bf[0:i]); err != nil { return nil } } return nil } if r.Cmd == CmdUDP { caddr, err := r.UDP(c, s.ServerAddr) if err != nil { return err } ch := make(chan byte) defer close(ch) s.AssociatedUDP.Set(caddr.String(), ch, -1) defer s.AssociatedUDP.Delete(caddr.String()) io.Copy(ioutil.Discard, c) if Debug { log.Printf("A tcp connection that udp %#v associated closed\n", caddr.String()) } return nil } return ErrUnsupportCmd } // UDPHandle auto handle packet. You may prefer to do yourself. 
func (h *DefaultHandle) UDPHandle(s *Server, addr *net.UDPAddr, d *Datagram) error { src := addr.String() var ch chan byte if s.LimitUDP { any, ok := s.AssociatedUDP.Get(src) if !ok { return fmt.Errorf("This udp address %s is not associated with tcp", src) } ch = any.(chan byte) } send := func(ue *UDPExchange, data []byte) error { select { case <-ch: return fmt.Errorf("This udp address %s is not associated with tcp", src) default: _, err := ue.RemoteConn.Write(data) if err != nil { return err } if Debug { log.Printf("Sent UDP data to remote. client: %#v server: %#v remote: %#v data: %#v\n", ue.ClientAddr.String(), ue.RemoteConn.LocalAddr().String(), ue.RemoteConn.RemoteAddr().String(), data) } } return nil } dst := d.Address() var ue *UDPExchange iue, ok := s.UDPExchanges.Get(src + dst) if ok { ue = iue.(*UDPExchange) err := send(ue, d.Data) if err == nil { return nil } if !strings.Contains(err.Error(), "closed") { return err } } if Debug { log.Printf("Call udp: %#v\n", dst) } raddr, err := net.ResolveUDPAddr("udp", dst) if err != nil { return err } rc, err := Dial.DialUDP("udp", nil, raddr) if err != nil { return err } ue = &UDPExchange{ ClientAddr: addr, RemoteConn: rc, } if Debug { log.Printf("Created remote UDP conn for client. 
client: %#v server: %#v remote: %#v\n", addr.String(), ue.RemoteConn.LocalAddr().String(), d.Address()) } if err := send(ue, d.Data); err != nil { ue.RemoteConn.Close() return err } s.UDPExchanges.Set(src+dst, ue, -1) go func(ue *UDPExchange, dst string) { defer func() { ue.RemoteConn.Close() s.UDPExchanges.Delete(ue.ClientAddr.String() + dst) }() var b [65507]byte for { select { case <-ch: if Debug { log.Printf("The tcp that udp address %s associated closed\n", ue.ClientAddr.String()) } return default: if s.UDPTimeout != 0 { if err := ue.RemoteConn.SetDeadline(time.Now().Add(time.Duration(s.UDPTimeout) * time.Second)); err != nil { log.Println(err) return } } n, err := ue.RemoteConn.Read(b[:]) if err != nil { return } if Debug { log.Printf("Got UDP data from remote. client: %#v server: %#v remote: %#v data: %#v\n", ue.ClientAddr.String(), ue.RemoteConn.LocalAddr().String(), ue.RemoteConn.RemoteAddr().String(), b[0:n]) } a, addr, port, err := ParseAddress(dst) if err != nil { log.Println(err) return } d1 := NewDatagram(a, addr, port, b[0:n]) if _, err := s.UDPConn.WriteToUDP(d1.Bytes(), ue.ClientAddr); err != nil { return } if Debug { log.Printf("Sent Datagram. client: %#v server: %#v remote: %#v data: %#v %#v %#v %#v %#v %#v datagram address: %#v\n", ue.ClientAddr.String(), ue.RemoteConn.LocalAddr().String(), ue.RemoteConn.RemoteAddr().String(), d1.Rsv, d1.Frag, d1.Atyp, d1.DstAddr, d1.DstPort, d1.Data, d1.Address()) } } } }(ue, dst) return nil }