1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-21 19:13:20 +08:00

feat(p2p): support tcp protocol

This commit is contained in:
源文雨
2024-07-16 21:38:45 +09:00
parent 17e1f6cac9
commit 739cf863f1
19 changed files with 393 additions and 26 deletions

View File

@@ -12,6 +12,11 @@ import (
"github.com/sirupsen/logrus"
)
var (
ErrBadCRCChecksum = errors.New("bad crc checksum")
ErrDataLenLT60 = errors.New("data len < 60")
)
type PacketFlags uint16
func (pf PacketFlags) IsValid() bool {
@@ -97,12 +102,12 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b
// Unmarshal 将 data 的数据解码到自身
func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
if len(data) < 60 {
err = errors.New("data len < 60")
err = ErrDataLenLT60
return
}
p.crc64 = binary.LittleEndian.Uint64(data[52:60])
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 {
err = errors.New("bad crc checksum")
err = ErrBadCRCChecksum
return
}

View File

@@ -12,6 +12,10 @@ import (
base14 "github.com/fumiama/go-base16384"
)
var (
ErrPerrNotExist = errors.New("peer not exist")
)
// Link 是本机到 peer 的连接抽象
type Link struct {
// peer 的公钥
@@ -56,7 +60,7 @@ func (m *Me) Connect(peer string) (*Link, error) {
if ok {
return p, nil
}
return nil, errors.New("peer not exist")
return nil, ErrPerrNotExist
}
// Close 关闭到 peer 的连接

View File

@@ -56,7 +56,7 @@ func (m *Me) listen() (conn p2p.Conn, err error) {
return
}
if err != nil {
logrus.Warnln("[listen] read from udp err, reconnect:", err)
logrus.Warnln("[listen] read from conn err, reconnect:", err)
conn, err = m.ep.Listen()
if err != nil {
logrus.Errorln("[listen] reconnect udp err:", err)

View File

@@ -51,12 +51,15 @@ type Me struct {
srcport, dstport, mtu, speedloop uint16
// 报头掩码
mask uint64
// 本机网络端点初始化配置
networkconfigs []any
}
type MyConfig struct {
MyIPwithMask string
MyEndpoint string
Network string
NetworkConfigs []any
PrivateKey *[32]byte
NIC lower.NICIO
SrcPort, DstPort, MTU, SpeedLoop uint16
@@ -71,7 +74,8 @@ func NewMe(cfg *MyConfig) (m Me) {
if nw == "" {
nw = "udp"
}
m.ep, err = p2p.NewEndPoint(nw, cfg.MyEndpoint)
m.networkconfigs = cfg.NetworkConfigs
m.ep, err = p2p.NewEndPoint(nw, cfg.MyEndpoint, m.networkconfigs...)
if err != nil {
panic(err)
}

View File

@@ -44,7 +44,12 @@ func (l *Link) onNotify(packet []byte) {
// ---- 遍历 Notify注册对方的 endpoint 到
// ---- connections注意使用读写锁connmapmu
for peer, ep := range notify {
addr, err := p2p.NewEndPoint(ep[0], ep[1])
nw, epstr := ep[0], ep[1]
if nw != l.me.ep.Network() {
logrus.Warnln("[nat] ignore different network notify", nw, "addr", epstr)
continue
}
addr, err := p2p.NewEndPoint(nw, epstr, l.me.networkconfigs...)
if err == nil {
p, ok := l.me.IsInPeer(peer)
if ok {

View File

@@ -73,7 +73,7 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) {
}
}
if cfg.EndPoint != "" {
e, err := p2p.NewEndPoint(m.ep.Network(), cfg.EndPoint)
e, err := p2p.NewEndPoint(m.ep.Network(), cfg.EndPoint, m.networkconfigs...)
if err != nil {
panic(err)
}

View File

@@ -14,6 +14,11 @@ import (
"github.com/sirupsen/logrus"
)
var (
ErrDropBigDontFragTransPkt = errors.New("drop big don't fragmnet trans packet")
ErrTTL = errors.New("ttl exceeded")
)
// WriteAndPut 向 peer 发包并将包放回缓存池
func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
defer p.Put()
@@ -37,7 +42,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false)
}
if istransfer && p.Flags.DontFrag() && remlen > delta {
return 0, errors.New("drop don't fragmnet big trans packet")
return 0, ErrDropBigDontFragTransPkt
}
ttl := p.TTL
totl := uint32(remlen)
@@ -93,11 +98,11 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui
d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore)
}
if d == nil {
return 0, errors.New("[send] ttl exceeded")
return 0, ErrTTL
}
peerep := l.endpoint
if peerep == nil {
return 0, errors.New("[send] nil endpoint of " + p.Dst.String())
return 0, errors.New("nil endpoint of " + p.Dst.String())
}
bound := 64
endl := "..."

View File

@@ -8,6 +8,10 @@ import (
"github.com/RomiChan/syncx"
)
var (
ErrEndpointTypeMistatch = errors.New("endpoint type mismatch")
)
type Initializer func(endpoint string, configs ...any) EndPoint
var factory syncx.Map[string, Initializer]

41
gold/p2p/tcp/init.go Normal file
View File

@@ -0,0 +1,41 @@
package tcp
import (
"net"
"net/netip"
"time"
"github.com/fumiama/WireGold/gold/p2p"
)
type Config struct {
PeersTimeout time.Duration
ReceiveChannelSize int
}
func NewEndpoint(endpoint string, configs ...any) p2p.EndPoint {
return newEndpoint(endpoint, configs...)
}
func newEndpoint(endpoint string, configs ...any) *EndPoint {
var cfg *Config
if len(configs) == 0 || configs[0] == nil {
cfg = &Config{}
} else {
cfg = configs[0].(*Config)
}
return &EndPoint{
addr: net.TCPAddrFromAddrPort(
netip.MustParseAddrPort(endpoint),
),
peerstimeout: cfg.PeersTimeout,
recvchansize: cfg.ReceiveChannelSize,
}
}
func init() {
_, hasexist := p2p.Register("tcp", NewEndpoint)
if hasexist {
panic("network tcp has been registered")
}
}

65
gold/p2p/tcp/pdu.go Normal file
View File

@@ -0,0 +1,65 @@
package tcp
import (
"encoding/binary"
"io"
"net"
"github.com/fumiama/WireGold/helper"
)
type packetType uint8
const (
packetTypeKeepAlive packetType = iota
packetTypeNormal
)
type packet struct {
typ packetType
len uint16
dat []byte
io.ReaderFrom
io.WriterTo
}
func (p *packet) pack() (net.Buffers, func()) {
d, cl := helper.OpenWriterF(func(w *helper.Writer) {
w.WriteByte(byte(p.typ))
w.WriteUInt16(p.len)
})
return net.Buffers{d, p.dat}, cl
}
func (p *packet) Read(_ []byte) (int, error) {
panic("stub")
}
func (p *packet) Write(_ []byte) (int, error) {
panic("stub")
}
func (p *packet) ReadFrom(r io.Reader) (n int64, err error) {
var buf [3]byte
cnt, err := io.ReadFull(r, buf[:])
n = int64(cnt)
if err != nil {
return
}
p.typ = packetType(buf[0])
p.len = binary.LittleEndian.Uint16(buf[1:3])
w := helper.SelectWriter()
copied, err := io.CopyN(w, r, int64(p.len))
n += copied
if err != nil {
return
}
p.dat = w.Bytes()
return
}
func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
buf, cl := p.pack()
defer cl()
return io.Copy(w, &buf)
}

219
gold/p2p/tcp/tcp.go Normal file
View File

@@ -0,0 +1,219 @@
package tcp
import (
"errors"
"io"
"math/rand"
"net"
"strconv"
"time"
"github.com/FloatTech/ttl"
"github.com/fumiama/WireGold/gold/p2p"
"github.com/fumiama/WireGold/helper"
"github.com/sirupsen/logrus"
)
type EndPoint struct {
addr *net.TCPAddr
peerstimeout time.Duration
recvchansize int
}
func (ep *EndPoint) String() string {
return ep.addr.String()
}
func (ep *EndPoint) Network() string {
return ep.addr.Network()
}
func (ep *EndPoint) Euqal(ep2 p2p.EndPoint) bool {
tcpep2, ok := ep2.(*EndPoint)
if !ok {
return false
}
tcpep1 := ep
return tcpep1.addr.IP.Equal(tcpep2.addr.IP) &&
tcpep1.addr.Port == tcpep2.addr.Port &&
tcpep1.addr.Zone == tcpep2.addr.Zone
}
func (ep *EndPoint) Listen() (p2p.Conn, error) {
lstn, err := net.ListenTCP(ep.addr.Network(), ep.addr)
if err != nil {
return nil, err
}
ep.addr = lstn.Addr().(*net.TCPAddr)
timeout := ep.peerstimeout
if timeout < time.Second {
timeout = time.Second
}
chansz := ep.recvchansize
if chansz < 32 {
chansz = 32
}
conn := &Conn{
addr: ep,
lstn: lstn,
peers: ttl.NewCacheOn(timeout, [4]func(string, *net.TCPConn){
nil,
nil,
func(s string, t *net.TCPConn) {
err := t.Close()
if err != nil {
logrus.Debugln("[tcp] close conn from", ep, "to", s, "err:", err)
} else {
logrus.Debugln("[tcp] close conn from", ep, "to", s)
}
},
ep.keepAlive,
}),
recv: make(chan *connrecv, chansz),
}
go conn.accept()
return conn, nil
}
func (ep *EndPoint) keepAlive(_ string, t *net.TCPConn) {
_, err := io.Copy(t, &packet{
typ: packetTypeKeepAlive,
len: 1,
dat: []byte{byte(rand.Intn(256))},
})
if err != nil {
logrus.Debugln("[tcp] write keepalive from", ep, "to conn", t.RemoteAddr(), "err:", err)
}
}
type connrecv struct {
addr *EndPoint // cast from tcpconn.RemoteAddr()
pckt packet
}
// Conn 伪装成无状态的有状态连接
type Conn struct {
addr *EndPoint
lstn *net.TCPListener
peers *ttl.Cache[string, *net.TCPConn]
recv chan *connrecv
}
func (conn *Conn) accept() {
for {
tcpconn, err := conn.lstn.AcceptTCP()
if err != nil {
if errors.Is(err, net.ErrClosed) { // normal close
logrus.Infoln("[tcp] accept of", conn.addr, "got closed")
return
}
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
return
}
logrus.Warnln("[tcp] accept on", conn.addr, "err:", err)
_ = conn.Close()
newc, err := conn.addr.Listen()
if err != nil {
logrus.Warn("[tcp] re-listen on", conn.addr, "err:", err)
return
}
*conn = *newc.(*Conn)
logrus.Info("[tcp] re-listen on", conn.addr)
continue
}
ep := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
PeersTimeout: conn.addr.peerstimeout,
ReceiveChannelSize: conn.addr.recvchansize,
})
logrus.Debugln("[tcp] accept from", ep)
conn.peers.Set(ep.String(), tcpconn)
go conn.receive(ep)
}
}
func (conn *Conn) receive(ep *EndPoint) {
for {
r := &connrecv{addr: ep}
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
return
}
tcpconn := conn.peers.Get(ep.String())
if tcpconn == nil {
return
}
_, err := io.Copy(&r.pckt, tcpconn)
if err != nil {
logrus.Debugln("[tcp] recv from", ep, "err:", err)
return
}
logrus.Debugln("[tcp] dispatch packet from", ep, "typ", r.pckt.typ, "len", r.pckt.len)
conn.recv <- r
}
}
func (conn *Conn) Close() error {
if conn.lstn != nil {
_ = conn.lstn.Close()
}
if conn.peers != nil {
conn.peers.Destroy()
}
if conn.recv != nil {
close(conn.recv)
}
conn.addr = nil
conn.lstn = nil
conn.peers = nil
conn.recv = nil
return nil
}
func (conn *Conn) String() string {
return conn.addr.String()
}
func (conn *Conn) LocalAddr() p2p.EndPoint {
return conn.addr
}
func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
var p *connrecv
for {
p = <-conn.recv
if p == nil {
return 0, nil, net.ErrClosed
}
if p.pckt.typ == packetTypeNormal {
break
}
defer helper.PutBytes(p.pckt.dat)
}
n := copy(b, p.pckt.dat)
return n, p.addr, nil
}
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
tcpep, ok := ep.(*EndPoint)
if !ok {
return 0, p2p.ErrEndpointTypeMistatch
}
blen := len(b)
if blen >= 65536 {
return 0, errors.New("data size " + strconv.Itoa(blen) + " is too large")
}
tcpconn := conn.peers.Get(tcpep.String())
if tcpconn == nil {
// must use another port to send because there's no exsiting conn
tcpconn, err = net.DialTCP(tcpep.Network(), nil, tcpep.addr)
if err != nil {
return
}
conn.peers.Set(tcpep.String(), tcpconn)
}
cnt, err := io.Copy(tcpconn, &packet{
typ: packetTypeNormal,
len: uint16(blen),
dat: b,
})
return int(cnt) - 3, err
}

View File

@@ -1,17 +1,12 @@
package udp
import (
"errors"
"net"
"net/netip"
"github.com/fumiama/WireGold/gold/p2p"
)
var (
ErrEndpointTypeMistatch = errors.New("endpoint type mismatch")
)
func NewEndpoint(endpoint string, _ ...any) p2p.EndPoint {
return (*EndPoint)(net.UDPAddrFromAddrPort(
netip.MustParseAddrPort(endpoint),

View File

@@ -52,7 +52,7 @@ func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (int, error) {
udpep, ok := ep.(*EndPoint)
if !ok {
return 0, ErrEndpointTypeMistatch
return 0, p2p.ErrEndpointTypeMistatch
}
return (*net.UDPConn)(conn).WriteTo(b, (*net.UDPAddr)(udpep))
}