mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-21 19:13:20 +08:00
feat(p2p): support tcp protocol
This commit is contained in:
@@ -12,6 +12,11 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBadCRCChecksum = errors.New("bad crc checksum")
|
||||
ErrDataLenLT60 = errors.New("data len < 60")
|
||||
)
|
||||
|
||||
type PacketFlags uint16
|
||||
|
||||
func (pf PacketFlags) IsValid() bool {
|
||||
@@ -97,12 +102,12 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b
|
||||
// Unmarshal 将 data 的数据解码到自身
|
||||
func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
||||
if len(data) < 60 {
|
||||
err = errors.New("data len < 60")
|
||||
err = ErrDataLenLT60
|
||||
return
|
||||
}
|
||||
p.crc64 = binary.LittleEndian.Uint64(data[52:60])
|
||||
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 {
|
||||
err = errors.New("bad crc checksum")
|
||||
err = ErrBadCRCChecksum
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,10 @@ import (
|
||||
base14 "github.com/fumiama/go-base16384"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPerrNotExist = errors.New("peer not exist")
|
||||
)
|
||||
|
||||
// Link 是本机到 peer 的连接抽象
|
||||
type Link struct {
|
||||
// peer 的公钥
|
||||
@@ -56,7 +60,7 @@ func (m *Me) Connect(peer string) (*Link, error) {
|
||||
if ok {
|
||||
return p, nil
|
||||
}
|
||||
return nil, errors.New("peer not exist")
|
||||
return nil, ErrPerrNotExist
|
||||
}
|
||||
|
||||
// Close 关闭到 peer 的连接
|
||||
|
||||
@@ -56,7 +56,7 @@ func (m *Me) listen() (conn p2p.Conn, err error) {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logrus.Warnln("[listen] read from udp err, reconnect:", err)
|
||||
logrus.Warnln("[listen] read from conn err, reconnect:", err)
|
||||
conn, err = m.ep.Listen()
|
||||
if err != nil {
|
||||
logrus.Errorln("[listen] reconnect udp err:", err)
|
||||
|
||||
@@ -51,12 +51,15 @@ type Me struct {
|
||||
srcport, dstport, mtu, speedloop uint16
|
||||
// 报头掩码
|
||||
mask uint64
|
||||
// 本机网络端点初始化配置
|
||||
networkconfigs []any
|
||||
}
|
||||
|
||||
type MyConfig struct {
|
||||
MyIPwithMask string
|
||||
MyEndpoint string
|
||||
Network string
|
||||
NetworkConfigs []any
|
||||
PrivateKey *[32]byte
|
||||
NIC lower.NICIO
|
||||
SrcPort, DstPort, MTU, SpeedLoop uint16
|
||||
@@ -71,7 +74,8 @@ func NewMe(cfg *MyConfig) (m Me) {
|
||||
if nw == "" {
|
||||
nw = "udp"
|
||||
}
|
||||
m.ep, err = p2p.NewEndPoint(nw, cfg.MyEndpoint)
|
||||
m.networkconfigs = cfg.NetworkConfigs
|
||||
m.ep, err = p2p.NewEndPoint(nw, cfg.MyEndpoint, m.networkconfigs...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -44,7 +44,12 @@ func (l *Link) onNotify(packet []byte) {
|
||||
// ---- 遍历 Notify,注册对方的 endpoint 到
|
||||
// ---- connections,注意使用读写锁connmapmu
|
||||
for peer, ep := range notify {
|
||||
addr, err := p2p.NewEndPoint(ep[0], ep[1])
|
||||
nw, epstr := ep[0], ep[1]
|
||||
if nw != l.me.ep.Network() {
|
||||
logrus.Warnln("[nat] ignore different network notify", nw, "addr", epstr)
|
||||
continue
|
||||
}
|
||||
addr, err := p2p.NewEndPoint(nw, epstr, l.me.networkconfigs...)
|
||||
if err == nil {
|
||||
p, ok := l.me.IsInPeer(peer)
|
||||
if ok {
|
||||
|
||||
@@ -73,7 +73,7 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) {
|
||||
}
|
||||
}
|
||||
if cfg.EndPoint != "" {
|
||||
e, err := p2p.NewEndPoint(m.ep.Network(), cfg.EndPoint)
|
||||
e, err := p2p.NewEndPoint(m.ep.Network(), cfg.EndPoint, m.networkconfigs...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,11 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDropBigDontFragTransPkt = errors.New("drop big don't fragmnet trans packet")
|
||||
ErrTTL = errors.New("ttl exceeded")
|
||||
)
|
||||
|
||||
// WriteAndPut 向 peer 发包并将包放回缓存池
|
||||
func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
|
||||
defer p.Put()
|
||||
@@ -37,7 +42,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
|
||||
return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false)
|
||||
}
|
||||
if istransfer && p.Flags.DontFrag() && remlen > delta {
|
||||
return 0, errors.New("drop don't fragmnet big trans packet")
|
||||
return 0, ErrDropBigDontFragTransPkt
|
||||
}
|
||||
ttl := p.TTL
|
||||
totl := uint32(remlen)
|
||||
@@ -93,11 +98,11 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui
|
||||
d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore)
|
||||
}
|
||||
if d == nil {
|
||||
return 0, errors.New("[send] ttl exceeded")
|
||||
return 0, ErrTTL
|
||||
}
|
||||
peerep := l.endpoint
|
||||
if peerep == nil {
|
||||
return 0, errors.New("[send] nil endpoint of " + p.Dst.String())
|
||||
return 0, errors.New("nil endpoint of " + p.Dst.String())
|
||||
}
|
||||
bound := 64
|
||||
endl := "..."
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
"github.com/RomiChan/syncx"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEndpointTypeMistatch = errors.New("endpoint type mismatch")
|
||||
)
|
||||
|
||||
type Initializer func(endpoint string, configs ...any) EndPoint
|
||||
|
||||
var factory syncx.Map[string, Initializer]
|
||||
|
||||
41
gold/p2p/tcp/init.go
Normal file
41
gold/p2p/tcp/init.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/fumiama/WireGold/gold/p2p"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
PeersTimeout time.Duration
|
||||
ReceiveChannelSize int
|
||||
}
|
||||
|
||||
func NewEndpoint(endpoint string, configs ...any) p2p.EndPoint {
|
||||
return newEndpoint(endpoint, configs...)
|
||||
}
|
||||
|
||||
func newEndpoint(endpoint string, configs ...any) *EndPoint {
|
||||
var cfg *Config
|
||||
if len(configs) == 0 || configs[0] == nil {
|
||||
cfg = &Config{}
|
||||
} else {
|
||||
cfg = configs[0].(*Config)
|
||||
}
|
||||
return &EndPoint{
|
||||
addr: net.TCPAddrFromAddrPort(
|
||||
netip.MustParseAddrPort(endpoint),
|
||||
),
|
||||
peerstimeout: cfg.PeersTimeout,
|
||||
recvchansize: cfg.ReceiveChannelSize,
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
_, hasexist := p2p.Register("tcp", NewEndpoint)
|
||||
if hasexist {
|
||||
panic("network tcp has been registered")
|
||||
}
|
||||
}
|
||||
65
gold/p2p/tcp/pdu.go
Normal file
65
gold/p2p/tcp/pdu.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/fumiama/WireGold/helper"
|
||||
)
|
||||
|
||||
type packetType uint8
|
||||
|
||||
const (
|
||||
packetTypeKeepAlive packetType = iota
|
||||
packetTypeNormal
|
||||
)
|
||||
|
||||
type packet struct {
|
||||
typ packetType
|
||||
len uint16
|
||||
dat []byte
|
||||
io.ReaderFrom
|
||||
io.WriterTo
|
||||
}
|
||||
|
||||
func (p *packet) pack() (net.Buffers, func()) {
|
||||
d, cl := helper.OpenWriterF(func(w *helper.Writer) {
|
||||
w.WriteByte(byte(p.typ))
|
||||
w.WriteUInt16(p.len)
|
||||
})
|
||||
return net.Buffers{d, p.dat}, cl
|
||||
}
|
||||
|
||||
func (p *packet) Read(_ []byte) (int, error) {
|
||||
panic("stub")
|
||||
}
|
||||
|
||||
func (p *packet) Write(_ []byte) (int, error) {
|
||||
panic("stub")
|
||||
}
|
||||
|
||||
func (p *packet) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
var buf [3]byte
|
||||
cnt, err := io.ReadFull(r, buf[:])
|
||||
n = int64(cnt)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p.typ = packetType(buf[0])
|
||||
p.len = binary.LittleEndian.Uint16(buf[1:3])
|
||||
w := helper.SelectWriter()
|
||||
copied, err := io.CopyN(w, r, int64(p.len))
|
||||
n += copied
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p.dat = w.Bytes()
|
||||
return
|
||||
}
|
||||
|
||||
func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
|
||||
buf, cl := p.pack()
|
||||
defer cl()
|
||||
return io.Copy(w, &buf)
|
||||
}
|
||||
219
gold/p2p/tcp/tcp.go
Normal file
219
gold/p2p/tcp/tcp.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/FloatTech/ttl"
|
||||
"github.com/fumiama/WireGold/gold/p2p"
|
||||
"github.com/fumiama/WireGold/helper"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type EndPoint struct {
|
||||
addr *net.TCPAddr
|
||||
peerstimeout time.Duration
|
||||
recvchansize int
|
||||
}
|
||||
|
||||
func (ep *EndPoint) String() string {
|
||||
return ep.addr.String()
|
||||
}
|
||||
|
||||
func (ep *EndPoint) Network() string {
|
||||
return ep.addr.Network()
|
||||
}
|
||||
|
||||
func (ep *EndPoint) Euqal(ep2 p2p.EndPoint) bool {
|
||||
tcpep2, ok := ep2.(*EndPoint)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
tcpep1 := ep
|
||||
return tcpep1.addr.IP.Equal(tcpep2.addr.IP) &&
|
||||
tcpep1.addr.Port == tcpep2.addr.Port &&
|
||||
tcpep1.addr.Zone == tcpep2.addr.Zone
|
||||
}
|
||||
|
||||
func (ep *EndPoint) Listen() (p2p.Conn, error) {
|
||||
lstn, err := net.ListenTCP(ep.addr.Network(), ep.addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ep.addr = lstn.Addr().(*net.TCPAddr)
|
||||
timeout := ep.peerstimeout
|
||||
if timeout < time.Second {
|
||||
timeout = time.Second
|
||||
}
|
||||
chansz := ep.recvchansize
|
||||
if chansz < 32 {
|
||||
chansz = 32
|
||||
}
|
||||
conn := &Conn{
|
||||
addr: ep,
|
||||
lstn: lstn,
|
||||
peers: ttl.NewCacheOn(timeout, [4]func(string, *net.TCPConn){
|
||||
nil,
|
||||
nil,
|
||||
func(s string, t *net.TCPConn) {
|
||||
err := t.Close()
|
||||
if err != nil {
|
||||
logrus.Debugln("[tcp] close conn from", ep, "to", s, "err:", err)
|
||||
} else {
|
||||
logrus.Debugln("[tcp] close conn from", ep, "to", s)
|
||||
}
|
||||
},
|
||||
ep.keepAlive,
|
||||
}),
|
||||
recv: make(chan *connrecv, chansz),
|
||||
}
|
||||
go conn.accept()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (ep *EndPoint) keepAlive(_ string, t *net.TCPConn) {
|
||||
_, err := io.Copy(t, &packet{
|
||||
typ: packetTypeKeepAlive,
|
||||
len: 1,
|
||||
dat: []byte{byte(rand.Intn(256))},
|
||||
})
|
||||
if err != nil {
|
||||
logrus.Debugln("[tcp] write keepalive from", ep, "to conn", t.RemoteAddr(), "err:", err)
|
||||
}
|
||||
}
|
||||
|
||||
type connrecv struct {
|
||||
addr *EndPoint // cast from tcpconn.RemoteAddr()
|
||||
pckt packet
|
||||
}
|
||||
|
||||
// Conn 伪装成无状态的有状态连接
|
||||
type Conn struct {
|
||||
addr *EndPoint
|
||||
lstn *net.TCPListener
|
||||
peers *ttl.Cache[string, *net.TCPConn]
|
||||
recv chan *connrecv
|
||||
}
|
||||
|
||||
func (conn *Conn) accept() {
|
||||
for {
|
||||
tcpconn, err := conn.lstn.AcceptTCP()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) { // normal close
|
||||
logrus.Infoln("[tcp] accept of", conn.addr, "got closed")
|
||||
return
|
||||
}
|
||||
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
|
||||
return
|
||||
}
|
||||
logrus.Warnln("[tcp] accept on", conn.addr, "err:", err)
|
||||
_ = conn.Close()
|
||||
newc, err := conn.addr.Listen()
|
||||
if err != nil {
|
||||
logrus.Warn("[tcp] re-listen on", conn.addr, "err:", err)
|
||||
return
|
||||
}
|
||||
*conn = *newc.(*Conn)
|
||||
logrus.Info("[tcp] re-listen on", conn.addr)
|
||||
continue
|
||||
}
|
||||
ep := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
|
||||
PeersTimeout: conn.addr.peerstimeout,
|
||||
ReceiveChannelSize: conn.addr.recvchansize,
|
||||
})
|
||||
logrus.Debugln("[tcp] accept from", ep)
|
||||
conn.peers.Set(ep.String(), tcpconn)
|
||||
go conn.receive(ep)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) receive(ep *EndPoint) {
|
||||
for {
|
||||
r := &connrecv{addr: ep}
|
||||
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
|
||||
return
|
||||
}
|
||||
tcpconn := conn.peers.Get(ep.String())
|
||||
if tcpconn == nil {
|
||||
return
|
||||
}
|
||||
_, err := io.Copy(&r.pckt, tcpconn)
|
||||
if err != nil {
|
||||
logrus.Debugln("[tcp] recv from", ep, "err:", err)
|
||||
return
|
||||
}
|
||||
logrus.Debugln("[tcp] dispatch packet from", ep, "typ", r.pckt.typ, "len", r.pckt.len)
|
||||
conn.recv <- r
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) Close() error {
|
||||
if conn.lstn != nil {
|
||||
_ = conn.lstn.Close()
|
||||
}
|
||||
if conn.peers != nil {
|
||||
conn.peers.Destroy()
|
||||
}
|
||||
if conn.recv != nil {
|
||||
close(conn.recv)
|
||||
}
|
||||
conn.addr = nil
|
||||
conn.lstn = nil
|
||||
conn.peers = nil
|
||||
conn.recv = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) String() string {
|
||||
return conn.addr.String()
|
||||
}
|
||||
|
||||
func (conn *Conn) LocalAddr() p2p.EndPoint {
|
||||
return conn.addr
|
||||
}
|
||||
|
||||
func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
|
||||
var p *connrecv
|
||||
for {
|
||||
p = <-conn.recv
|
||||
if p == nil {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
if p.pckt.typ == packetTypeNormal {
|
||||
break
|
||||
}
|
||||
defer helper.PutBytes(p.pckt.dat)
|
||||
}
|
||||
n := copy(b, p.pckt.dat)
|
||||
return n, p.addr, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
|
||||
tcpep, ok := ep.(*EndPoint)
|
||||
if !ok {
|
||||
return 0, p2p.ErrEndpointTypeMistatch
|
||||
}
|
||||
blen := len(b)
|
||||
if blen >= 65536 {
|
||||
return 0, errors.New("data size " + strconv.Itoa(blen) + " is too large")
|
||||
}
|
||||
tcpconn := conn.peers.Get(tcpep.String())
|
||||
if tcpconn == nil {
|
||||
// must use another port to send because there's no exsiting conn
|
||||
tcpconn, err = net.DialTCP(tcpep.Network(), nil, tcpep.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.peers.Set(tcpep.String(), tcpconn)
|
||||
}
|
||||
cnt, err := io.Copy(tcpconn, &packet{
|
||||
typ: packetTypeNormal,
|
||||
len: uint16(blen),
|
||||
dat: b,
|
||||
})
|
||||
return int(cnt) - 3, err
|
||||
}
|
||||
@@ -1,17 +1,12 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fumiama/WireGold/gold/p2p"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEndpointTypeMistatch = errors.New("endpoint type mismatch")
|
||||
)
|
||||
|
||||
func NewEndpoint(endpoint string, _ ...any) p2p.EndPoint {
|
||||
return (*EndPoint)(net.UDPAddrFromAddrPort(
|
||||
netip.MustParseAddrPort(endpoint),
|
||||
|
||||
@@ -52,7 +52,7 @@ func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
|
||||
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (int, error) {
|
||||
udpep, ok := ep.(*EndPoint)
|
||||
if !ok {
|
||||
return 0, ErrEndpointTypeMistatch
|
||||
return 0, p2p.ErrEndpointTypeMistatch
|
||||
}
|
||||
return (*net.UDPConn)(conn).WriteTo(b, (*net.UDPAddr)(udpep))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user