mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-27 14:20:27 +08:00
feat(p2p): support tcp protocol
This commit is contained in:
@@ -8,6 +8,10 @@ import (
|
||||
"github.com/RomiChan/syncx"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEndpointTypeMistatch = errors.New("endpoint type mismatch")
|
||||
)
|
||||
|
||||
type Initializer func(endpoint string, configs ...any) EndPoint
|
||||
|
||||
var factory syncx.Map[string, Initializer]
|
||||
|
||||
41
gold/p2p/tcp/init.go
Normal file
41
gold/p2p/tcp/init.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/fumiama/WireGold/gold/p2p"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
PeersTimeout time.Duration
|
||||
ReceiveChannelSize int
|
||||
}
|
||||
|
||||
func NewEndpoint(endpoint string, configs ...any) p2p.EndPoint {
|
||||
return newEndpoint(endpoint, configs...)
|
||||
}
|
||||
|
||||
func newEndpoint(endpoint string, configs ...any) *EndPoint {
|
||||
var cfg *Config
|
||||
if len(configs) == 0 || configs[0] == nil {
|
||||
cfg = &Config{}
|
||||
} else {
|
||||
cfg = configs[0].(*Config)
|
||||
}
|
||||
return &EndPoint{
|
||||
addr: net.TCPAddrFromAddrPort(
|
||||
netip.MustParseAddrPort(endpoint),
|
||||
),
|
||||
peerstimeout: cfg.PeersTimeout,
|
||||
recvchansize: cfg.ReceiveChannelSize,
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
_, hasexist := p2p.Register("tcp", NewEndpoint)
|
||||
if hasexist {
|
||||
panic("network tcp has been registered")
|
||||
}
|
||||
}
|
||||
65
gold/p2p/tcp/pdu.go
Normal file
65
gold/p2p/tcp/pdu.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/fumiama/WireGold/helper"
|
||||
)
|
||||
|
||||
type packetType uint8
|
||||
|
||||
const (
|
||||
packetTypeKeepAlive packetType = iota
|
||||
packetTypeNormal
|
||||
)
|
||||
|
||||
type packet struct {
|
||||
typ packetType
|
||||
len uint16
|
||||
dat []byte
|
||||
io.ReaderFrom
|
||||
io.WriterTo
|
||||
}
|
||||
|
||||
func (p *packet) pack() (net.Buffers, func()) {
|
||||
d, cl := helper.OpenWriterF(func(w *helper.Writer) {
|
||||
w.WriteByte(byte(p.typ))
|
||||
w.WriteUInt16(p.len)
|
||||
})
|
||||
return net.Buffers{d, p.dat}, cl
|
||||
}
|
||||
|
||||
func (p *packet) Read(_ []byte) (int, error) {
|
||||
panic("stub")
|
||||
}
|
||||
|
||||
func (p *packet) Write(_ []byte) (int, error) {
|
||||
panic("stub")
|
||||
}
|
||||
|
||||
func (p *packet) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
var buf [3]byte
|
||||
cnt, err := io.ReadFull(r, buf[:])
|
||||
n = int64(cnt)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p.typ = packetType(buf[0])
|
||||
p.len = binary.LittleEndian.Uint16(buf[1:3])
|
||||
w := helper.SelectWriter()
|
||||
copied, err := io.CopyN(w, r, int64(p.len))
|
||||
n += copied
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p.dat = w.Bytes()
|
||||
return
|
||||
}
|
||||
|
||||
func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
|
||||
buf, cl := p.pack()
|
||||
defer cl()
|
||||
return io.Copy(w, &buf)
|
||||
}
|
||||
219
gold/p2p/tcp/tcp.go
Normal file
219
gold/p2p/tcp/tcp.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/FloatTech/ttl"
|
||||
"github.com/fumiama/WireGold/gold/p2p"
|
||||
"github.com/fumiama/WireGold/helper"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type EndPoint struct {
|
||||
addr *net.TCPAddr
|
||||
peerstimeout time.Duration
|
||||
recvchansize int
|
||||
}
|
||||
|
||||
func (ep *EndPoint) String() string {
|
||||
return ep.addr.String()
|
||||
}
|
||||
|
||||
func (ep *EndPoint) Network() string {
|
||||
return ep.addr.Network()
|
||||
}
|
||||
|
||||
func (ep *EndPoint) Euqal(ep2 p2p.EndPoint) bool {
|
||||
tcpep2, ok := ep2.(*EndPoint)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
tcpep1 := ep
|
||||
return tcpep1.addr.IP.Equal(tcpep2.addr.IP) &&
|
||||
tcpep1.addr.Port == tcpep2.addr.Port &&
|
||||
tcpep1.addr.Zone == tcpep2.addr.Zone
|
||||
}
|
||||
|
||||
func (ep *EndPoint) Listen() (p2p.Conn, error) {
|
||||
lstn, err := net.ListenTCP(ep.addr.Network(), ep.addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ep.addr = lstn.Addr().(*net.TCPAddr)
|
||||
timeout := ep.peerstimeout
|
||||
if timeout < time.Second {
|
||||
timeout = time.Second
|
||||
}
|
||||
chansz := ep.recvchansize
|
||||
if chansz < 32 {
|
||||
chansz = 32
|
||||
}
|
||||
conn := &Conn{
|
||||
addr: ep,
|
||||
lstn: lstn,
|
||||
peers: ttl.NewCacheOn(timeout, [4]func(string, *net.TCPConn){
|
||||
nil,
|
||||
nil,
|
||||
func(s string, t *net.TCPConn) {
|
||||
err := t.Close()
|
||||
if err != nil {
|
||||
logrus.Debugln("[tcp] close conn from", ep, "to", s, "err:", err)
|
||||
} else {
|
||||
logrus.Debugln("[tcp] close conn from", ep, "to", s)
|
||||
}
|
||||
},
|
||||
ep.keepAlive,
|
||||
}),
|
||||
recv: make(chan *connrecv, chansz),
|
||||
}
|
||||
go conn.accept()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (ep *EndPoint) keepAlive(_ string, t *net.TCPConn) {
|
||||
_, err := io.Copy(t, &packet{
|
||||
typ: packetTypeKeepAlive,
|
||||
len: 1,
|
||||
dat: []byte{byte(rand.Intn(256))},
|
||||
})
|
||||
if err != nil {
|
||||
logrus.Debugln("[tcp] write keepalive from", ep, "to conn", t.RemoteAddr(), "err:", err)
|
||||
}
|
||||
}
|
||||
|
||||
type connrecv struct {
|
||||
addr *EndPoint // cast from tcpconn.RemoteAddr()
|
||||
pckt packet
|
||||
}
|
||||
|
||||
// Conn 伪装成无状态的有状态连接
|
||||
type Conn struct {
|
||||
addr *EndPoint
|
||||
lstn *net.TCPListener
|
||||
peers *ttl.Cache[string, *net.TCPConn]
|
||||
recv chan *connrecv
|
||||
}
|
||||
|
||||
func (conn *Conn) accept() {
|
||||
for {
|
||||
tcpconn, err := conn.lstn.AcceptTCP()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) { // normal close
|
||||
logrus.Infoln("[tcp] accept of", conn.addr, "got closed")
|
||||
return
|
||||
}
|
||||
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
|
||||
return
|
||||
}
|
||||
logrus.Warnln("[tcp] accept on", conn.addr, "err:", err)
|
||||
_ = conn.Close()
|
||||
newc, err := conn.addr.Listen()
|
||||
if err != nil {
|
||||
logrus.Warn("[tcp] re-listen on", conn.addr, "err:", err)
|
||||
return
|
||||
}
|
||||
*conn = *newc.(*Conn)
|
||||
logrus.Info("[tcp] re-listen on", conn.addr)
|
||||
continue
|
||||
}
|
||||
ep := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
|
||||
PeersTimeout: conn.addr.peerstimeout,
|
||||
ReceiveChannelSize: conn.addr.recvchansize,
|
||||
})
|
||||
logrus.Debugln("[tcp] accept from", ep)
|
||||
conn.peers.Set(ep.String(), tcpconn)
|
||||
go conn.receive(ep)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) receive(ep *EndPoint) {
|
||||
for {
|
||||
r := &connrecv{addr: ep}
|
||||
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
|
||||
return
|
||||
}
|
||||
tcpconn := conn.peers.Get(ep.String())
|
||||
if tcpconn == nil {
|
||||
return
|
||||
}
|
||||
_, err := io.Copy(&r.pckt, tcpconn)
|
||||
if err != nil {
|
||||
logrus.Debugln("[tcp] recv from", ep, "err:", err)
|
||||
return
|
||||
}
|
||||
logrus.Debugln("[tcp] dispatch packet from", ep, "typ", r.pckt.typ, "len", r.pckt.len)
|
||||
conn.recv <- r
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) Close() error {
|
||||
if conn.lstn != nil {
|
||||
_ = conn.lstn.Close()
|
||||
}
|
||||
if conn.peers != nil {
|
||||
conn.peers.Destroy()
|
||||
}
|
||||
if conn.recv != nil {
|
||||
close(conn.recv)
|
||||
}
|
||||
conn.addr = nil
|
||||
conn.lstn = nil
|
||||
conn.peers = nil
|
||||
conn.recv = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) String() string {
|
||||
return conn.addr.String()
|
||||
}
|
||||
|
||||
func (conn *Conn) LocalAddr() p2p.EndPoint {
|
||||
return conn.addr
|
||||
}
|
||||
|
||||
func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
|
||||
var p *connrecv
|
||||
for {
|
||||
p = <-conn.recv
|
||||
if p == nil {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
if p.pckt.typ == packetTypeNormal {
|
||||
break
|
||||
}
|
||||
defer helper.PutBytes(p.pckt.dat)
|
||||
}
|
||||
n := copy(b, p.pckt.dat)
|
||||
return n, p.addr, nil
|
||||
}
|
||||
|
||||
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
|
||||
tcpep, ok := ep.(*EndPoint)
|
||||
if !ok {
|
||||
return 0, p2p.ErrEndpointTypeMistatch
|
||||
}
|
||||
blen := len(b)
|
||||
if blen >= 65536 {
|
||||
return 0, errors.New("data size " + strconv.Itoa(blen) + " is too large")
|
||||
}
|
||||
tcpconn := conn.peers.Get(tcpep.String())
|
||||
if tcpconn == nil {
|
||||
// must use another port to send because there's no exsiting conn
|
||||
tcpconn, err = net.DialTCP(tcpep.Network(), nil, tcpep.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.peers.Set(tcpep.String(), tcpconn)
|
||||
}
|
||||
cnt, err := io.Copy(tcpconn, &packet{
|
||||
typ: packetTypeNormal,
|
||||
len: uint16(blen),
|
||||
dat: b,
|
||||
})
|
||||
return int(cnt) - 3, err
|
||||
}
|
||||
@@ -1,17 +1,12 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/fumiama/WireGold/gold/p2p"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEndpointTypeMistatch = errors.New("endpoint type mismatch")
|
||||
)
|
||||
|
||||
func NewEndpoint(endpoint string, _ ...any) p2p.EndPoint {
|
||||
return (*EndPoint)(net.UDPAddrFromAddrPort(
|
||||
netip.MustParseAddrPort(endpoint),
|
||||
|
||||
@@ -52,7 +52,7 @@ func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
|
||||
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (int, error) {
|
||||
udpep, ok := ep.(*EndPoint)
|
||||
if !ok {
|
||||
return 0, ErrEndpointTypeMistatch
|
||||
return 0, p2p.ErrEndpointTypeMistatch
|
||||
}
|
||||
return (*net.UDPConn)(conn).WriteTo(b, (*net.UDPAddr)(udpep))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user