1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-27 14:20:27 +08:00

feat(p2p): support tcp protocol

This commit is contained in:
源文雨
2024-07-16 21:38:45 +09:00
parent 17e1f6cac9
commit 739cf863f1
19 changed files with 393 additions and 26 deletions

View File

@@ -8,6 +8,10 @@ import (
"github.com/RomiChan/syncx"
)
var (
ErrEndpointTypeMistatch = errors.New("endpoint type mismatch")
)
type Initializer func(endpoint string, configs ...any) EndPoint
var factory syncx.Map[string, Initializer]

41
gold/p2p/tcp/init.go Normal file
View File

@@ -0,0 +1,41 @@
package tcp
import (
"net"
"net/netip"
"time"
"github.com/fumiama/WireGold/gold/p2p"
)
type Config struct {
PeersTimeout time.Duration
ReceiveChannelSize int
}
func NewEndpoint(endpoint string, configs ...any) p2p.EndPoint {
return newEndpoint(endpoint, configs...)
}
func newEndpoint(endpoint string, configs ...any) *EndPoint {
var cfg *Config
if len(configs) == 0 || configs[0] == nil {
cfg = &Config{}
} else {
cfg = configs[0].(*Config)
}
return &EndPoint{
addr: net.TCPAddrFromAddrPort(
netip.MustParseAddrPort(endpoint),
),
peerstimeout: cfg.PeersTimeout,
recvchansize: cfg.ReceiveChannelSize,
}
}
func init() {
_, hasexist := p2p.Register("tcp", NewEndpoint)
if hasexist {
panic("network tcp has been registered")
}
}

65
gold/p2p/tcp/pdu.go Normal file
View File

@@ -0,0 +1,65 @@
package tcp
import (
"encoding/binary"
"io"
"net"
"github.com/fumiama/WireGold/helper"
)
type packetType uint8
const (
packetTypeKeepAlive packetType = iota
packetTypeNormal
)
type packet struct {
typ packetType
len uint16
dat []byte
io.ReaderFrom
io.WriterTo
}
func (p *packet) pack() (net.Buffers, func()) {
d, cl := helper.OpenWriterF(func(w *helper.Writer) {
w.WriteByte(byte(p.typ))
w.WriteUInt16(p.len)
})
return net.Buffers{d, p.dat}, cl
}
func (p *packet) Read(_ []byte) (int, error) {
panic("stub")
}
func (p *packet) Write(_ []byte) (int, error) {
panic("stub")
}
func (p *packet) ReadFrom(r io.Reader) (n int64, err error) {
var buf [3]byte
cnt, err := io.ReadFull(r, buf[:])
n = int64(cnt)
if err != nil {
return
}
p.typ = packetType(buf[0])
p.len = binary.LittleEndian.Uint16(buf[1:3])
w := helper.SelectWriter()
copied, err := io.CopyN(w, r, int64(p.len))
n += copied
if err != nil {
return
}
p.dat = w.Bytes()
return
}
func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
buf, cl := p.pack()
defer cl()
return io.Copy(w, &buf)
}

219
gold/p2p/tcp/tcp.go Normal file
View File

@@ -0,0 +1,219 @@
package tcp
import (
"errors"
"io"
"math/rand"
"net"
"strconv"
"time"
"github.com/FloatTech/ttl"
"github.com/fumiama/WireGold/gold/p2p"
"github.com/fumiama/WireGold/helper"
"github.com/sirupsen/logrus"
)
type EndPoint struct {
addr *net.TCPAddr
peerstimeout time.Duration
recvchansize int
}
func (ep *EndPoint) String() string {
return ep.addr.String()
}
func (ep *EndPoint) Network() string {
return ep.addr.Network()
}
func (ep *EndPoint) Euqal(ep2 p2p.EndPoint) bool {
tcpep2, ok := ep2.(*EndPoint)
if !ok {
return false
}
tcpep1 := ep
return tcpep1.addr.IP.Equal(tcpep2.addr.IP) &&
tcpep1.addr.Port == tcpep2.addr.Port &&
tcpep1.addr.Zone == tcpep2.addr.Zone
}
func (ep *EndPoint) Listen() (p2p.Conn, error) {
lstn, err := net.ListenTCP(ep.addr.Network(), ep.addr)
if err != nil {
return nil, err
}
ep.addr = lstn.Addr().(*net.TCPAddr)
timeout := ep.peerstimeout
if timeout < time.Second {
timeout = time.Second
}
chansz := ep.recvchansize
if chansz < 32 {
chansz = 32
}
conn := &Conn{
addr: ep,
lstn: lstn,
peers: ttl.NewCacheOn(timeout, [4]func(string, *net.TCPConn){
nil,
nil,
func(s string, t *net.TCPConn) {
err := t.Close()
if err != nil {
logrus.Debugln("[tcp] close conn from", ep, "to", s, "err:", err)
} else {
logrus.Debugln("[tcp] close conn from", ep, "to", s)
}
},
ep.keepAlive,
}),
recv: make(chan *connrecv, chansz),
}
go conn.accept()
return conn, nil
}
func (ep *EndPoint) keepAlive(_ string, t *net.TCPConn) {
_, err := io.Copy(t, &packet{
typ: packetTypeKeepAlive,
len: 1,
dat: []byte{byte(rand.Intn(256))},
})
if err != nil {
logrus.Debugln("[tcp] write keepalive from", ep, "to conn", t.RemoteAddr(), "err:", err)
}
}
type connrecv struct {
addr *EndPoint // cast from tcpconn.RemoteAddr()
pckt packet
}
// Conn 伪装成无状态的有状态连接
type Conn struct {
addr *EndPoint
lstn *net.TCPListener
peers *ttl.Cache[string, *net.TCPConn]
recv chan *connrecv
}
func (conn *Conn) accept() {
for {
tcpconn, err := conn.lstn.AcceptTCP()
if err != nil {
if errors.Is(err, net.ErrClosed) { // normal close
logrus.Infoln("[tcp] accept of", conn.addr, "got closed")
return
}
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
return
}
logrus.Warnln("[tcp] accept on", conn.addr, "err:", err)
_ = conn.Close()
newc, err := conn.addr.Listen()
if err != nil {
logrus.Warn("[tcp] re-listen on", conn.addr, "err:", err)
return
}
*conn = *newc.(*Conn)
logrus.Info("[tcp] re-listen on", conn.addr)
continue
}
ep := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
PeersTimeout: conn.addr.peerstimeout,
ReceiveChannelSize: conn.addr.recvchansize,
})
logrus.Debugln("[tcp] accept from", ep)
conn.peers.Set(ep.String(), tcpconn)
go conn.receive(ep)
}
}
func (conn *Conn) receive(ep *EndPoint) {
for {
r := &connrecv{addr: ep}
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
return
}
tcpconn := conn.peers.Get(ep.String())
if tcpconn == nil {
return
}
_, err := io.Copy(&r.pckt, tcpconn)
if err != nil {
logrus.Debugln("[tcp] recv from", ep, "err:", err)
return
}
logrus.Debugln("[tcp] dispatch packet from", ep, "typ", r.pckt.typ, "len", r.pckt.len)
conn.recv <- r
}
}
func (conn *Conn) Close() error {
if conn.lstn != nil {
_ = conn.lstn.Close()
}
if conn.peers != nil {
conn.peers.Destroy()
}
if conn.recv != nil {
close(conn.recv)
}
conn.addr = nil
conn.lstn = nil
conn.peers = nil
conn.recv = nil
return nil
}
func (conn *Conn) String() string {
return conn.addr.String()
}
func (conn *Conn) LocalAddr() p2p.EndPoint {
return conn.addr
}
func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
var p *connrecv
for {
p = <-conn.recv
if p == nil {
return 0, nil, net.ErrClosed
}
if p.pckt.typ == packetTypeNormal {
break
}
defer helper.PutBytes(p.pckt.dat)
}
n := copy(b, p.pckt.dat)
return n, p.addr, nil
}
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
tcpep, ok := ep.(*EndPoint)
if !ok {
return 0, p2p.ErrEndpointTypeMistatch
}
blen := len(b)
if blen >= 65536 {
return 0, errors.New("data size " + strconv.Itoa(blen) + " is too large")
}
tcpconn := conn.peers.Get(tcpep.String())
if tcpconn == nil {
// must use another port to send because there's no exsiting conn
tcpconn, err = net.DialTCP(tcpep.Network(), nil, tcpep.addr)
if err != nil {
return
}
conn.peers.Set(tcpep.String(), tcpconn)
}
cnt, err := io.Copy(tcpconn, &packet{
typ: packetTypeNormal,
len: uint16(blen),
dat: b,
})
return int(cnt) - 3, err
}

View File

@@ -1,17 +1,12 @@
package udp
import (
"errors"
"net"
"net/netip"
"github.com/fumiama/WireGold/gold/p2p"
)
var (
ErrEndpointTypeMistatch = errors.New("endpoint type mismatch")
)
func NewEndpoint(endpoint string, _ ...any) p2p.EndPoint {
return (*EndPoint)(net.UDPAddrFromAddrPort(
netip.MustParseAddrPort(endpoint),

View File

@@ -52,7 +52,7 @@ func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (int, error) {
udpep, ok := ep.(*EndPoint)
if !ok {
return 0, ErrEndpointTypeMistatch
return 0, p2p.ErrEndpointTypeMistatch
}
return (*net.UDPConn)(conn).WriteTo(b, (*net.UDPAddr)(udpep))
}