diff --git a/config/cfg.go b/config/cfg.go index a851396..dfac4ae 100644 --- a/config/cfg.go +++ b/config/cfg.go @@ -13,7 +13,7 @@ type Config struct { IP string `yaml:"IP"` SubNet string `yaml:"SubNet"` PrivateKey string `yaml:"PrivateKey"` - Network string `yaml:"Network"` // Network udp or ws (WIP) + Network string `yaml:"Network"` // Network udp, tcp or ws (WIP) EndPoint string `yaml:"EndPoint"` MTU int64 `yaml:"MTU"` SpeedLoop uint16 `yaml:"SpeedLoop"` diff --git a/go.mod b/go.mod index 9dc1031..55159d2 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/fumiama/WireGold go 1.20 require ( - github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1 + github.com/FloatTech/ttl v0.0.0-20240715074357-190755f3fece github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 github.com/fumiama/blake2b-simd v0.0.0-20220412110131-4481822068bb github.com/fumiama/go-base16384 v1.7.0 diff --git a/go.sum b/go.sum index 6d1ff6b..0f278bb 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1 h1:g4pTnDJUW4VbJ9NvoRfUvdjDrHz/6QhfN/LoIIpICbo= -github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs= +github.com/FloatTech/ttl v0.0.0-20240715074357-190755f3fece h1:RIrGO+hIOoXxUh0T3TDaWNvinkXH9S2i12cWivT2MZ4= +github.com/FloatTech/ttl v0.0.0-20240715074357-190755f3fece/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs= github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 h1:S/ferNiehVjNaBMNNBxUjLtVmP/YWD6Yh79RfPv4ehU= github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7/go.mod h1:vD7Ra3Q9onRtojoY5sMCLQ7JBgjUsrXDnDKyFxqpf9w= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/gold/head/packet.go b/gold/head/packet.go index 0699413..17cd7cd 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -12,6 +12,11 @@ import ( "github.com/sirupsen/logrus" ) +var ( + ErrBadCRCChecksum = errors.New("bad crc checksum") + ErrDataLenLT60 = errors.New("data len < 60") +) + type PacketFlags uint16 func (pf PacketFlags) IsValid() bool { @@ -97,12 +102,12 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b // Unmarshal 将 data 的数据解码到自身 func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { if len(data) < 60 { - err = errors.New("data len < 60") + err = ErrDataLenLT60 return } p.crc64 = binary.LittleEndian.Uint64(data[52:60]) if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 { - err = errors.New("bad crc checksum") + err = ErrBadCRCChecksum return } diff --git a/gold/link/link.go b/gold/link/link.go index 1fe4b0c..8144b48 100644 --- a/gold/link/link.go +++ b/gold/link/link.go @@ -12,6 +12,10 @@ import ( base14 "github.com/fumiama/go-base16384" ) +var ( + ErrPerrNotExist = errors.New("peer not exist") +) + // Link 是本机到 peer 的连接抽象 type Link struct { // peer 的公钥 @@ -56,7 +60,7 @@ func (m *Me) Connect(peer string) (*Link, error) { if ok { return p, nil } - return nil, errors.New("peer not exist") + return nil, ErrPerrNotExist } // Close 关闭到 peer 的连接 diff --git a/gold/link/listen.go b/gold/link/listen.go index 9b00c01..7221974 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -56,7 +56,7 @@ func (m *Me) listen() (conn p2p.Conn, err error) { return } if err != nil { - logrus.Warnln("[listen] read from udp err, reconnect:", err) + logrus.Warnln("[listen] read from conn err, reconnect:", err) conn, err = m.ep.Listen() if err != nil { logrus.Errorln("[listen] reconnect udp err:", err) diff --git a/gold/link/me.go b/gold/link/me.go index 4309182..bd051ee 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -51,12 +51,15 @@ type Me struct { srcport, dstport, mtu, speedloop uint16 // 报头掩码 mask uint64 + // 本机网络端点初始化配置 + networkconfigs []any } type MyConfig struct { MyIPwithMask string MyEndpoint string Network string + NetworkConfigs []any PrivateKey *[32]byte NIC lower.NICIO SrcPort, DstPort, MTU, SpeedLoop uint16 @@ -71,7 +74,8 @@ func NewMe(cfg *MyConfig) (m Me) { if nw == "" { nw = "udp" } - m.ep, err = p2p.NewEndPoint(nw, cfg.MyEndpoint) + m.networkconfigs = cfg.NetworkConfigs + m.ep, err = p2p.NewEndPoint(nw, cfg.MyEndpoint, m.networkconfigs...) if err != nil { panic(err) } diff --git a/gold/link/nat.go b/gold/link/nat.go index 9850400..5e4942a 100644 --- a/gold/link/nat.go +++ b/gold/link/nat.go @@ -44,7 +44,12 @@ func (l *Link) onNotify(packet []byte) { // ---- 遍历 Notify,注册对方的 endpoint 到 // ---- connections,注意使用读写锁connmapmu for peer, ep := range notify { - addr, err := p2p.NewEndPoint(ep[0], ep[1]) + nw, epstr := ep[0], ep[1] + if nw != l.me.ep.Network() { + logrus.Warnln("[nat] ignore different network notify", nw, "addr", epstr) + continue + } + addr, err := p2p.NewEndPoint(nw, epstr, l.me.networkconfigs...) if err == nil { p, ok := l.me.IsInPeer(peer) if ok { diff --git a/gold/link/peer.go b/gold/link/peer.go index ae2e6c0..ee42562 100644 --- a/gold/link/peer.go +++ b/gold/link/peer.go @@ -73,7 +73,7 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) { } } if cfg.EndPoint != "" { - e, err := p2p.NewEndPoint(m.ep.Network(), cfg.EndPoint) + e, err := p2p.NewEndPoint(m.ep.Network(), cfg.EndPoint, m.networkconfigs...) if err != nil { panic(err) } diff --git a/gold/link/send.go b/gold/link/send.go index 5167afa..3215e00 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -14,6 +14,11 @@ import ( "github.com/sirupsen/logrus" ) +var ( + ErrDropBigDontFragTransPkt = errors.New("drop big don't fragmnet trans packet") + ErrTTL = errors.New("ttl exceeded") +) + // WriteAndPut 向 peer 发包并将包放回缓存池 func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { defer p.Put() @@ -37,7 +42,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false) } if istransfer && p.Flags.DontFrag() && remlen > delta { - return 0, errors.New("drop don't fragmnet big trans packet") + return 0, ErrDropBigDontFragTransPkt } ttl := p.TTL totl := uint32(remlen) @@ -93,11 +98,11 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore) } if d == nil { - return 0, errors.New("[send] ttl exceeded") + return 0, ErrTTL } peerep := l.endpoint if peerep == nil { - return 0, errors.New("[send] nil endpoint of " + p.Dst.String()) + return 0, errors.New("nil endpoint of " + p.Dst.String()) } bound := 64 endl := "..." diff --git a/gold/p2p/define.go b/gold/p2p/define.go index b6e1795..1959280 100644 --- a/gold/p2p/define.go +++ b/gold/p2p/define.go @@ -8,6 +8,10 @@ import ( "github.com/RomiChan/syncx" ) +var ( + ErrEndpointTypeMistatch = errors.New("endpoint type mismatch") +) + type Initializer func(endpoint string, configs ...any) EndPoint var factory syncx.Map[string, Initializer] diff --git a/gold/p2p/tcp/init.go b/gold/p2p/tcp/init.go new file mode 100644 index 0000000..cb240cc --- /dev/null +++ b/gold/p2p/tcp/init.go @@ -0,0 +1,41 @@ +package tcp + +import ( + "net" + "net/netip" + "time" + + "github.com/fumiama/WireGold/gold/p2p" +) + +type Config struct { + PeersTimeout time.Duration + ReceiveChannelSize int +} + +func NewEndpoint(endpoint string, configs ...any) p2p.EndPoint { + return newEndpoint(endpoint, configs...) +} + +func newEndpoint(endpoint string, configs ...any) *EndPoint { + var cfg *Config + if len(configs) == 0 || configs[0] == nil { + cfg = &Config{} + } else { + cfg = configs[0].(*Config) + } + return &EndPoint{ + addr: net.TCPAddrFromAddrPort( + netip.MustParseAddrPort(endpoint), + ), + peerstimeout: cfg.PeersTimeout, + recvchansize: cfg.ReceiveChannelSize, + } +} + +func init() { + _, hasexist := p2p.Register("tcp", NewEndpoint) + if hasexist { + panic("network tcp has been registered") + } +} diff --git a/gold/p2p/tcp/pdu.go b/gold/p2p/tcp/pdu.go new file mode 100644 index 0000000..25f5488 --- /dev/null +++ b/gold/p2p/tcp/pdu.go @@ -0,0 +1,65 @@ +package tcp + +import ( + "encoding/binary" + "io" + "net" + + "github.com/fumiama/WireGold/helper" +) + +type packetType uint8 + +const ( + packetTypeKeepAlive packetType = iota + packetTypeNormal +) + +type packet struct { + typ packetType + len uint16 + dat []byte + io.ReaderFrom + io.WriterTo +} + +func (p *packet) pack() (net.Buffers, func()) { + d, cl := helper.OpenWriterF(func(w *helper.Writer) { + w.WriteByte(byte(p.typ)) + w.WriteUInt16(p.len) + }) + return net.Buffers{d, p.dat}, cl +} + +func (p *packet) Read(_ []byte) (int, error) { + panic("stub") +} + +func (p *packet) Write(_ []byte) (int, error) { + panic("stub") +} + +func (p *packet) ReadFrom(r io.Reader) (n int64, err error) { + var buf [3]byte + cnt, err := io.ReadFull(r, buf[:]) + n = int64(cnt) + if err != nil { + return + } + p.typ = packetType(buf[0]) + p.len = binary.LittleEndian.Uint16(buf[1:3]) + w := helper.SelectWriter() + copied, err := io.CopyN(w, r, int64(p.len)) + n += copied + if err != nil { + return + } + p.dat = w.Bytes() + return +} + +func (p *packet) WriteTo(w io.Writer) (n int64, err error) { + buf, cl := p.pack() + defer cl() + return io.Copy(w, &buf) +} diff --git a/gold/p2p/tcp/tcp.go b/gold/p2p/tcp/tcp.go new file mode 100644 index 0000000..3604923 --- /dev/null +++ b/gold/p2p/tcp/tcp.go @@ -0,0 +1,219 @@ +package tcp + +import ( + "errors" + "io" + "math/rand" + "net" + "strconv" + "time" + + "github.com/FloatTech/ttl" + "github.com/fumiama/WireGold/gold/p2p" + "github.com/fumiama/WireGold/helper" + "github.com/sirupsen/logrus" +) + +type EndPoint struct { + addr *net.TCPAddr + peerstimeout time.Duration + recvchansize int +} + +func (ep *EndPoint) String() string { + return ep.addr.String() +} + +func (ep *EndPoint) Network() string { + return ep.addr.Network() +} + +func (ep *EndPoint) Euqal(ep2 p2p.EndPoint) bool { + tcpep2, ok := ep2.(*EndPoint) + if !ok { + return false + } + tcpep1 := ep + return tcpep1.addr.IP.Equal(tcpep2.addr.IP) && + tcpep1.addr.Port == tcpep2.addr.Port && + tcpep1.addr.Zone == tcpep2.addr.Zone +} + +func (ep *EndPoint) Listen() (p2p.Conn, error) { + lstn, err := net.ListenTCP(ep.addr.Network(), ep.addr) + if err != nil { + return nil, err + } + ep.addr = lstn.Addr().(*net.TCPAddr) + timeout := ep.peerstimeout + if timeout < time.Second { + timeout = time.Second + } + chansz := ep.recvchansize + if chansz < 32 { + chansz = 32 + } + conn := &Conn{ + addr: ep, + lstn: lstn, + peers: ttl.NewCacheOn(timeout, [4]func(string, *net.TCPConn){ + nil, + nil, + func(s string, t *net.TCPConn) { + err := t.Close() + if err != nil { + logrus.Debugln("[tcp] close conn from", ep, "to", s, "err:", err) + } else { + logrus.Debugln("[tcp] close conn from", ep, "to", s) + } + }, + ep.keepAlive, + }), + recv: make(chan *connrecv, chansz), + } + go conn.accept() + return conn, nil +} + +func (ep *EndPoint) keepAlive(_ string, t *net.TCPConn) { + _, err := io.Copy(t, &packet{ + typ: packetTypeKeepAlive, + len: 1, + dat: []byte{byte(rand.Intn(256))}, + }) + if err != nil { + logrus.Debugln("[tcp] write keepalive from", ep, "to conn", t.RemoteAddr(), "err:", err) + } +} + +type connrecv struct { + addr *EndPoint // cast from tcpconn.RemoteAddr() + pckt packet +} + +// Conn 伪装成无状态的有状态连接 +type Conn struct { + addr *EndPoint + lstn *net.TCPListener + peers *ttl.Cache[string, *net.TCPConn] + recv chan *connrecv +} + +func (conn *Conn) accept() { + for { + tcpconn, err := conn.lstn.AcceptTCP() + if err != nil { + if errors.Is(err, net.ErrClosed) { // normal close + logrus.Infoln("[tcp] accept of", conn.addr, "got closed") + return + } + if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil { + return + } + logrus.Warnln("[tcp] accept on", conn.addr, "err:", err) + _ = conn.Close() + newc, err := conn.addr.Listen() + if err != nil { + logrus.Warn("[tcp] re-listen on", conn.addr, "err:", err) + return + } + *conn = *newc.(*Conn) + logrus.Info("[tcp] re-listen on", conn.addr) + continue + } + ep := newEndpoint(tcpconn.RemoteAddr().String(), &Config{ + PeersTimeout: conn.addr.peerstimeout, + ReceiveChannelSize: conn.addr.recvchansize, + }) + logrus.Debugln("[tcp] accept from", ep) + conn.peers.Set(ep.String(), tcpconn) + go conn.receive(ep) + } +} + +func (conn *Conn) receive(ep *EndPoint) { + for { + r := &connrecv{addr: ep} + if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil { + return + } + tcpconn := conn.peers.Get(ep.String()) + if tcpconn == nil { + return + } + _, err := io.Copy(&r.pckt, tcpconn) + if err != nil { + logrus.Debugln("[tcp] recv from", ep, "err:", err) + return + } + logrus.Debugln("[tcp] dispatch packet from", ep, "typ", r.pckt.typ, "len", r.pckt.len) + conn.recv <- r + } +} + +func (conn *Conn) Close() error { + if conn.lstn != nil { + _ = conn.lstn.Close() + } + if conn.peers != nil { + conn.peers.Destroy() + } + if conn.recv != nil { + close(conn.recv) + } + conn.addr = nil + conn.lstn = nil + conn.peers = nil + conn.recv = nil + return nil +} + +func (conn *Conn) String() string { + return conn.addr.String() +} + +func (conn *Conn) LocalAddr() p2p.EndPoint { + return conn.addr +} + +func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) { + var p *connrecv + for { + p = <-conn.recv + if p == nil { + return 0, nil, net.ErrClosed + } + if p.pckt.typ == packetTypeNormal { + break + } + defer helper.PutBytes(p.pckt.dat) + } + n := copy(b, p.pckt.dat) + return n, p.addr, nil +} + +func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) { + tcpep, ok := ep.(*EndPoint) + if !ok { + return 0, p2p.ErrEndpointTypeMistatch + } + blen := len(b) + if blen >= 65536 { + return 0, errors.New("data size " + strconv.Itoa(blen) + " is too large") + } + tcpconn := conn.peers.Get(tcpep.String()) + if tcpconn == nil { + // must use another port to send because there's no exsiting conn + tcpconn, err = net.DialTCP(tcpep.Network(), nil, tcpep.addr) + if err != nil { + return + } + conn.peers.Set(tcpep.String(), tcpconn) + } + cnt, err := io.Copy(tcpconn, &packet{ + typ: packetTypeNormal, + len: uint16(blen), + dat: b, + }) + return int(cnt) - 3, err +} diff --git a/gold/p2p/udp/init.go b/gold/p2p/udp/init.go index 0cdf014..c352b64 100644 --- a/gold/p2p/udp/init.go +++ b/gold/p2p/udp/init.go @@ -1,17 +1,12 @@ package udp import ( - "errors" "net" "net/netip" "github.com/fumiama/WireGold/gold/p2p" ) -var ( - ErrEndpointTypeMistatch = errors.New("endpoint type mismatch") -) - func NewEndpoint(endpoint string, _ ...any) p2p.EndPoint { return (*EndPoint)(net.UDPAddrFromAddrPort( netip.MustParseAddrPort(endpoint), diff --git a/gold/p2p/udp/udp.go b/gold/p2p/udp/udp.go index 3b40374..7694bb7 100644 --- a/gold/p2p/udp/udp.go +++ b/gold/p2p/udp/udp.go @@ -52,7 +52,7 @@ func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) { func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (int, error) { udpep, ok := ep.(*EndPoint) if !ok { - return 0, ErrEndpointTypeMistatch + return 0, p2p.ErrEndpointTypeMistatch } return (*net.UDPConn)(conn).WriteTo(b, (*net.UDPAddr)(udpep)) } diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index 7c40bfb..fc33d8e 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -8,6 +8,7 @@ import ( "github.com/sirupsen/logrus" + _ "github.com/fumiama/WireGold/gold/p2p/tcp" // support tcp connection _ "github.com/fumiama/WireGold/gold/p2p/udp" // support udp connection "github.com/fumiama/WireGold/gold/head" diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index 89434d1..d381b52 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -16,7 +16,7 @@ import ( "github.com/fumiama/WireGold/helper" ) -func testTunnel(t *testing.T, isplain bool, pshk *[32]byte) { +func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) { selfpk, err := curve.New(nil) if err != nil { panic(err) @@ -33,6 +33,7 @@ func testTunnel(t *testing.T, isplain bool, pshk *[32]byte) { m := link.NewMe(&link.MyConfig{ MyIPwithMask: "192.168.1.2/32", MyEndpoint: "127.0.0.1:0", + Network: nw, PrivateKey: selfpk.Private(), SrcPort: 1, DstPort: 1, @@ -43,6 +44,7 @@ func testTunnel(t *testing.T, isplain bool, pshk *[32]byte) { p := link.NewMe(&link.MyConfig{ MyIPwithMask: "192.168.1.3/32", MyEndpoint: "127.0.0.1:0", + Network: nw, PrivateKey: peerpk.Private(), SrcPort: 1, DstPort: 1, @@ -146,20 +148,36 @@ func testTunnel(t *testing.T, isplain bool, pshk *[32]byte) { } } -func TestTunnel(t *testing.T) { +func TestTunnelUDP(t *testing.T) { logrus.SetLevel(logrus.DebugLevel) logrus.SetFormatter(&logFormat{enableColor: false}) - testTunnel(t, true, nil) // test plain text + testTunnel(t, "udp", true, nil) // test plain text - testTunnel(t, false, nil) // test normal + testTunnel(t, "udp", false, nil) // test normal var buf [32]byte _, err := rand.Read(buf[:]) if err != nil { panic(err) } - testTunnel(t, false, &buf) // test preshared + testTunnel(t, "udp", false, &buf) // test preshared +} + +func TestTunnelTCP(t *testing.T) { + logrus.SetLevel(logrus.DebugLevel) + logrus.SetFormatter(&logFormat{enableColor: false}) + + testTunnel(t, "tcp", true, nil) // test plain text + + testTunnel(t, "tcp", false, nil) // test normal + + var buf [32]byte + _, err := rand.Read(buf[:]) + if err != nil { + panic(err) + } + testTunnel(t, "tcp", false, &buf) // test preshared } // logFormat specialize for go-cqhttp diff --git a/upper/services/wg/wg.go b/upper/services/wg/wg.go index 0cc4bab..29de02b 100644 --- a/upper/services/wg/wg.go +++ b/upper/services/wg/wg.go @@ -9,6 +9,7 @@ import ( curve "github.com/fumiama/go-x25519" "github.com/sirupsen/logrus" + _ "github.com/fumiama/WireGold/gold/p2p/tcp" // support tcp connection _ "github.com/fumiama/WireGold/gold/p2p/udp" // support udp connection "github.com/fumiama/WireGold/config" @@ -36,7 +37,7 @@ func NewWireGold(c *config.Config) (wg WG, err error) { } n := copy(wg.key[:], base14.Decode(k)) if n != 32 { - err = errors.New("private key length is not 32") + err = errors.New("private key length != 32, got " + strconv.Itoa(n)) return }