mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-23 03:50:32 +08:00
feat(p2p): support tcp protocol
This commit is contained in:
@@ -13,7 +13,7 @@ type Config struct {
|
|||||||
IP string `yaml:"IP"`
|
IP string `yaml:"IP"`
|
||||||
SubNet string `yaml:"SubNet"`
|
SubNet string `yaml:"SubNet"`
|
||||||
PrivateKey string `yaml:"PrivateKey"`
|
PrivateKey string `yaml:"PrivateKey"`
|
||||||
Network string `yaml:"Network"` // Network udp or ws (WIP)
|
Network string `yaml:"Network"` // Network udp, tcp or ws (WIP)
|
||||||
EndPoint string `yaml:"EndPoint"`
|
EndPoint string `yaml:"EndPoint"`
|
||||||
MTU int64 `yaml:"MTU"`
|
MTU int64 `yaml:"MTU"`
|
||||||
SpeedLoop uint16 `yaml:"SpeedLoop"`
|
SpeedLoop uint16 `yaml:"SpeedLoop"`
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -3,7 +3,7 @@ module github.com/fumiama/WireGold
|
|||||||
go 1.20
|
go 1.20
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1
|
github.com/FloatTech/ttl v0.0.0-20240715074357-190755f3fece
|
||||||
github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7
|
github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7
|
||||||
github.com/fumiama/blake2b-simd v0.0.0-20220412110131-4481822068bb
|
github.com/fumiama/blake2b-simd v0.0.0-20220412110131-4481822068bb
|
||||||
github.com/fumiama/go-base16384 v1.7.0
|
github.com/fumiama/go-base16384 v1.7.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -1,5 +1,5 @@
|
|||||||
github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1 h1:g4pTnDJUW4VbJ9NvoRfUvdjDrHz/6QhfN/LoIIpICbo=
|
github.com/FloatTech/ttl v0.0.0-20240715074357-190755f3fece h1:RIrGO+hIOoXxUh0T3TDaWNvinkXH9S2i12cWivT2MZ4=
|
||||||
github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs=
|
github.com/FloatTech/ttl v0.0.0-20240715074357-190755f3fece/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs=
|
||||||
github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 h1:S/ferNiehVjNaBMNNBxUjLtVmP/YWD6Yh79RfPv4ehU=
|
github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 h1:S/ferNiehVjNaBMNNBxUjLtVmP/YWD6Yh79RfPv4ehU=
|
||||||
github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7/go.mod h1:vD7Ra3Q9onRtojoY5sMCLQ7JBgjUsrXDnDKyFxqpf9w=
|
github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7/go.mod h1:vD7Ra3Q9onRtojoY5sMCLQ7JBgjUsrXDnDKyFxqpf9w=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
|||||||
@@ -12,6 +12,11 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBadCRCChecksum = errors.New("bad crc checksum")
|
||||||
|
ErrDataLenLT60 = errors.New("data len < 60")
|
||||||
|
)
|
||||||
|
|
||||||
type PacketFlags uint16
|
type PacketFlags uint16
|
||||||
|
|
||||||
func (pf PacketFlags) IsValid() bool {
|
func (pf PacketFlags) IsValid() bool {
|
||||||
@@ -97,12 +102,12 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b
|
|||||||
// Unmarshal 将 data 的数据解码到自身
|
// Unmarshal 将 data 的数据解码到自身
|
||||||
func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
||||||
if len(data) < 60 {
|
if len(data) < 60 {
|
||||||
err = errors.New("data len < 60")
|
err = ErrDataLenLT60
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.crc64 = binary.LittleEndian.Uint64(data[52:60])
|
p.crc64 = binary.LittleEndian.Uint64(data[52:60])
|
||||||
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 {
|
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 {
|
||||||
err = errors.New("bad crc checksum")
|
err = ErrBadCRCChecksum
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ import (
|
|||||||
base14 "github.com/fumiama/go-base16384"
|
base14 "github.com/fumiama/go-base16384"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPerrNotExist = errors.New("peer not exist")
|
||||||
|
)
|
||||||
|
|
||||||
// Link 是本机到 peer 的连接抽象
|
// Link 是本机到 peer 的连接抽象
|
||||||
type Link struct {
|
type Link struct {
|
||||||
// peer 的公钥
|
// peer 的公钥
|
||||||
@@ -56,7 +60,7 @@ func (m *Me) Connect(peer string) (*Link, error) {
|
|||||||
if ok {
|
if ok {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("peer not exist")
|
return nil, ErrPerrNotExist
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close 关闭到 peer 的连接
|
// Close 关闭到 peer 的连接
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func (m *Me) listen() (conn p2p.Conn, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Warnln("[listen] read from udp err, reconnect:", err)
|
logrus.Warnln("[listen] read from conn err, reconnect:", err)
|
||||||
conn, err = m.ep.Listen()
|
conn, err = m.ep.Listen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("[listen] reconnect udp err:", err)
|
logrus.Errorln("[listen] reconnect udp err:", err)
|
||||||
|
|||||||
@@ -51,12 +51,15 @@ type Me struct {
|
|||||||
srcport, dstport, mtu, speedloop uint16
|
srcport, dstport, mtu, speedloop uint16
|
||||||
// 报头掩码
|
// 报头掩码
|
||||||
mask uint64
|
mask uint64
|
||||||
|
// 本机网络端点初始化配置
|
||||||
|
networkconfigs []any
|
||||||
}
|
}
|
||||||
|
|
||||||
type MyConfig struct {
|
type MyConfig struct {
|
||||||
MyIPwithMask string
|
MyIPwithMask string
|
||||||
MyEndpoint string
|
MyEndpoint string
|
||||||
Network string
|
Network string
|
||||||
|
NetworkConfigs []any
|
||||||
PrivateKey *[32]byte
|
PrivateKey *[32]byte
|
||||||
NIC lower.NICIO
|
NIC lower.NICIO
|
||||||
SrcPort, DstPort, MTU, SpeedLoop uint16
|
SrcPort, DstPort, MTU, SpeedLoop uint16
|
||||||
@@ -71,7 +74,8 @@ func NewMe(cfg *MyConfig) (m Me) {
|
|||||||
if nw == "" {
|
if nw == "" {
|
||||||
nw = "udp"
|
nw = "udp"
|
||||||
}
|
}
|
||||||
m.ep, err = p2p.NewEndPoint(nw, cfg.MyEndpoint)
|
m.networkconfigs = cfg.NetworkConfigs
|
||||||
|
m.ep, err = p2p.NewEndPoint(nw, cfg.MyEndpoint, m.networkconfigs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,12 @@ func (l *Link) onNotify(packet []byte) {
|
|||||||
// ---- 遍历 Notify,注册对方的 endpoint 到
|
// ---- 遍历 Notify,注册对方的 endpoint 到
|
||||||
// ---- connections,注意使用读写锁connmapmu
|
// ---- connections,注意使用读写锁connmapmu
|
||||||
for peer, ep := range notify {
|
for peer, ep := range notify {
|
||||||
addr, err := p2p.NewEndPoint(ep[0], ep[1])
|
nw, epstr := ep[0], ep[1]
|
||||||
|
if nw != l.me.ep.Network() {
|
||||||
|
logrus.Warnln("[nat] ignore different network notify", nw, "addr", epstr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addr, err := p2p.NewEndPoint(nw, epstr, l.me.networkconfigs...)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
p, ok := l.me.IsInPeer(peer)
|
p, ok := l.me.IsInPeer(peer)
|
||||||
if ok {
|
if ok {
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if cfg.EndPoint != "" {
|
if cfg.EndPoint != "" {
|
||||||
e, err := p2p.NewEndPoint(m.ep.Network(), cfg.EndPoint)
|
e, err := p2p.NewEndPoint(m.ep.Network(), cfg.EndPoint, m.networkconfigs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,11 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrDropBigDontFragTransPkt = errors.New("drop big don't fragmnet trans packet")
|
||||||
|
ErrTTL = errors.New("ttl exceeded")
|
||||||
|
)
|
||||||
|
|
||||||
// WriteAndPut 向 peer 发包并将包放回缓存池
|
// WriteAndPut 向 peer 发包并将包放回缓存池
|
||||||
func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
|
func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
|
||||||
defer p.Put()
|
defer p.Put()
|
||||||
@@ -37,7 +42,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
|
|||||||
return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false)
|
return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false)
|
||||||
}
|
}
|
||||||
if istransfer && p.Flags.DontFrag() && remlen > delta {
|
if istransfer && p.Flags.DontFrag() && remlen > delta {
|
||||||
return 0, errors.New("drop don't fragmnet big trans packet")
|
return 0, ErrDropBigDontFragTransPkt
|
||||||
}
|
}
|
||||||
ttl := p.TTL
|
ttl := p.TTL
|
||||||
totl := uint32(remlen)
|
totl := uint32(remlen)
|
||||||
@@ -93,11 +98,11 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui
|
|||||||
d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore)
|
d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore)
|
||||||
}
|
}
|
||||||
if d == nil {
|
if d == nil {
|
||||||
return 0, errors.New("[send] ttl exceeded")
|
return 0, ErrTTL
|
||||||
}
|
}
|
||||||
peerep := l.endpoint
|
peerep := l.endpoint
|
||||||
if peerep == nil {
|
if peerep == nil {
|
||||||
return 0, errors.New("[send] nil endpoint of " + p.Dst.String())
|
return 0, errors.New("nil endpoint of " + p.Dst.String())
|
||||||
}
|
}
|
||||||
bound := 64
|
bound := 64
|
||||||
endl := "..."
|
endl := "..."
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ import (
|
|||||||
"github.com/RomiChan/syncx"
|
"github.com/RomiChan/syncx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrEndpointTypeMistatch = errors.New("endpoint type mismatch")
|
||||||
|
)
|
||||||
|
|
||||||
type Initializer func(endpoint string, configs ...any) EndPoint
|
type Initializer func(endpoint string, configs ...any) EndPoint
|
||||||
|
|
||||||
var factory syncx.Map[string, Initializer]
|
var factory syncx.Map[string, Initializer]
|
||||||
|
|||||||
41
gold/p2p/tcp/init.go
Normal file
41
gold/p2p/tcp/init.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fumiama/WireGold/gold/p2p"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
PeersTimeout time.Duration
|
||||||
|
ReceiveChannelSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEndpoint(endpoint string, configs ...any) p2p.EndPoint {
|
||||||
|
return newEndpoint(endpoint, configs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEndpoint(endpoint string, configs ...any) *EndPoint {
|
||||||
|
var cfg *Config
|
||||||
|
if len(configs) == 0 || configs[0] == nil {
|
||||||
|
cfg = &Config{}
|
||||||
|
} else {
|
||||||
|
cfg = configs[0].(*Config)
|
||||||
|
}
|
||||||
|
return &EndPoint{
|
||||||
|
addr: net.TCPAddrFromAddrPort(
|
||||||
|
netip.MustParseAddrPort(endpoint),
|
||||||
|
),
|
||||||
|
peerstimeout: cfg.PeersTimeout,
|
||||||
|
recvchansize: cfg.ReceiveChannelSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
_, hasexist := p2p.Register("tcp", NewEndpoint)
|
||||||
|
if hasexist {
|
||||||
|
panic("network tcp has been registered")
|
||||||
|
}
|
||||||
|
}
|
||||||
65
gold/p2p/tcp/pdu.go
Normal file
65
gold/p2p/tcp/pdu.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/fumiama/WireGold/helper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type packetType uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
packetTypeKeepAlive packetType = iota
|
||||||
|
packetTypeNormal
|
||||||
|
)
|
||||||
|
|
||||||
|
type packet struct {
|
||||||
|
typ packetType
|
||||||
|
len uint16
|
||||||
|
dat []byte
|
||||||
|
io.ReaderFrom
|
||||||
|
io.WriterTo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *packet) pack() (net.Buffers, func()) {
|
||||||
|
d, cl := helper.OpenWriterF(func(w *helper.Writer) {
|
||||||
|
w.WriteByte(byte(p.typ))
|
||||||
|
w.WriteUInt16(p.len)
|
||||||
|
})
|
||||||
|
return net.Buffers{d, p.dat}, cl
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *packet) Read(_ []byte) (int, error) {
|
||||||
|
panic("stub")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *packet) Write(_ []byte) (int, error) {
|
||||||
|
panic("stub")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *packet) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
|
var buf [3]byte
|
||||||
|
cnt, err := io.ReadFull(r, buf[:])
|
||||||
|
n = int64(cnt)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.typ = packetType(buf[0])
|
||||||
|
p.len = binary.LittleEndian.Uint16(buf[1:3])
|
||||||
|
w := helper.SelectWriter()
|
||||||
|
copied, err := io.CopyN(w, r, int64(p.len))
|
||||||
|
n += copied
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.dat = w.Bytes()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
|
buf, cl := p.pack()
|
||||||
|
defer cl()
|
||||||
|
return io.Copy(w, &buf)
|
||||||
|
}
|
||||||
219
gold/p2p/tcp/tcp.go
Normal file
219
gold/p2p/tcp/tcp.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/FloatTech/ttl"
|
||||||
|
"github.com/fumiama/WireGold/gold/p2p"
|
||||||
|
"github.com/fumiama/WireGold/helper"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EndPoint struct {
|
||||||
|
addr *net.TCPAddr
|
||||||
|
peerstimeout time.Duration
|
||||||
|
recvchansize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *EndPoint) String() string {
|
||||||
|
return ep.addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *EndPoint) Network() string {
|
||||||
|
return ep.addr.Network()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *EndPoint) Euqal(ep2 p2p.EndPoint) bool {
|
||||||
|
tcpep2, ok := ep2.(*EndPoint)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tcpep1 := ep
|
||||||
|
return tcpep1.addr.IP.Equal(tcpep2.addr.IP) &&
|
||||||
|
tcpep1.addr.Port == tcpep2.addr.Port &&
|
||||||
|
tcpep1.addr.Zone == tcpep2.addr.Zone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *EndPoint) Listen() (p2p.Conn, error) {
|
||||||
|
lstn, err := net.ListenTCP(ep.addr.Network(), ep.addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ep.addr = lstn.Addr().(*net.TCPAddr)
|
||||||
|
timeout := ep.peerstimeout
|
||||||
|
if timeout < time.Second {
|
||||||
|
timeout = time.Second
|
||||||
|
}
|
||||||
|
chansz := ep.recvchansize
|
||||||
|
if chansz < 32 {
|
||||||
|
chansz = 32
|
||||||
|
}
|
||||||
|
conn := &Conn{
|
||||||
|
addr: ep,
|
||||||
|
lstn: lstn,
|
||||||
|
peers: ttl.NewCacheOn(timeout, [4]func(string, *net.TCPConn){
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
func(s string, t *net.TCPConn) {
|
||||||
|
err := t.Close()
|
||||||
|
if err != nil {
|
||||||
|
logrus.Debugln("[tcp] close conn from", ep, "to", s, "err:", err)
|
||||||
|
} else {
|
||||||
|
logrus.Debugln("[tcp] close conn from", ep, "to", s)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ep.keepAlive,
|
||||||
|
}),
|
||||||
|
recv: make(chan *connrecv, chansz),
|
||||||
|
}
|
||||||
|
go conn.accept()
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ep *EndPoint) keepAlive(_ string, t *net.TCPConn) {
|
||||||
|
_, err := io.Copy(t, &packet{
|
||||||
|
typ: packetTypeKeepAlive,
|
||||||
|
len: 1,
|
||||||
|
dat: []byte{byte(rand.Intn(256))},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logrus.Debugln("[tcp] write keepalive from", ep, "to conn", t.RemoteAddr(), "err:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type connrecv struct {
|
||||||
|
addr *EndPoint // cast from tcpconn.RemoteAddr()
|
||||||
|
pckt packet
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn 伪装成无状态的有状态连接
|
||||||
|
type Conn struct {
|
||||||
|
addr *EndPoint
|
||||||
|
lstn *net.TCPListener
|
||||||
|
peers *ttl.Cache[string, *net.TCPConn]
|
||||||
|
recv chan *connrecv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) accept() {
|
||||||
|
for {
|
||||||
|
tcpconn, err := conn.lstn.AcceptTCP()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) { // normal close
|
||||||
|
logrus.Infoln("[tcp] accept of", conn.addr, "got closed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logrus.Warnln("[tcp] accept on", conn.addr, "err:", err)
|
||||||
|
_ = conn.Close()
|
||||||
|
newc, err := conn.addr.Listen()
|
||||||
|
if err != nil {
|
||||||
|
logrus.Warn("[tcp] re-listen on", conn.addr, "err:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*conn = *newc.(*Conn)
|
||||||
|
logrus.Info("[tcp] re-listen on", conn.addr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ep := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
|
||||||
|
PeersTimeout: conn.addr.peerstimeout,
|
||||||
|
ReceiveChannelSize: conn.addr.recvchansize,
|
||||||
|
})
|
||||||
|
logrus.Debugln("[tcp] accept from", ep)
|
||||||
|
conn.peers.Set(ep.String(), tcpconn)
|
||||||
|
go conn.receive(ep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) receive(ep *EndPoint) {
|
||||||
|
for {
|
||||||
|
r := &connrecv{addr: ep}
|
||||||
|
if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tcpconn := conn.peers.Get(ep.String())
|
||||||
|
if tcpconn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err := io.Copy(&r.pckt, tcpconn)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Debugln("[tcp] recv from", ep, "err:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logrus.Debugln("[tcp] dispatch packet from", ep, "typ", r.pckt.typ, "len", r.pckt.len)
|
||||||
|
conn.recv <- r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) Close() error {
|
||||||
|
if conn.lstn != nil {
|
||||||
|
_ = conn.lstn.Close()
|
||||||
|
}
|
||||||
|
if conn.peers != nil {
|
||||||
|
conn.peers.Destroy()
|
||||||
|
}
|
||||||
|
if conn.recv != nil {
|
||||||
|
close(conn.recv)
|
||||||
|
}
|
||||||
|
conn.addr = nil
|
||||||
|
conn.lstn = nil
|
||||||
|
conn.peers = nil
|
||||||
|
conn.recv = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) String() string {
|
||||||
|
return conn.addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) LocalAddr() p2p.EndPoint {
|
||||||
|
return conn.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
|
||||||
|
var p *connrecv
|
||||||
|
for {
|
||||||
|
p = <-conn.recv
|
||||||
|
if p == nil {
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
if p.pckt.typ == packetTypeNormal {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
defer helper.PutBytes(p.pckt.dat)
|
||||||
|
}
|
||||||
|
n := copy(b, p.pckt.dat)
|
||||||
|
return n, p.addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
|
||||||
|
tcpep, ok := ep.(*EndPoint)
|
||||||
|
if !ok {
|
||||||
|
return 0, p2p.ErrEndpointTypeMistatch
|
||||||
|
}
|
||||||
|
blen := len(b)
|
||||||
|
if blen >= 65536 {
|
||||||
|
return 0, errors.New("data size " + strconv.Itoa(blen) + " is too large")
|
||||||
|
}
|
||||||
|
tcpconn := conn.peers.Get(tcpep.String())
|
||||||
|
if tcpconn == nil {
|
||||||
|
// must use another port to send because there's no exsiting conn
|
||||||
|
tcpconn, err = net.DialTCP(tcpep.Network(), nil, tcpep.addr)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.peers.Set(tcpep.String(), tcpconn)
|
||||||
|
}
|
||||||
|
cnt, err := io.Copy(tcpconn, &packet{
|
||||||
|
typ: packetTypeNormal,
|
||||||
|
len: uint16(blen),
|
||||||
|
dat: b,
|
||||||
|
})
|
||||||
|
return int(cnt) - 3, err
|
||||||
|
}
|
||||||
@@ -1,17 +1,12 @@
|
|||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/fumiama/WireGold/gold/p2p"
|
"github.com/fumiama/WireGold/gold/p2p"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrEndpointTypeMistatch = errors.New("endpoint type mismatch")
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewEndpoint(endpoint string, _ ...any) p2p.EndPoint {
|
func NewEndpoint(endpoint string, _ ...any) p2p.EndPoint {
|
||||||
return (*EndPoint)(net.UDPAddrFromAddrPort(
|
return (*EndPoint)(net.UDPAddrFromAddrPort(
|
||||||
netip.MustParseAddrPort(endpoint),
|
netip.MustParseAddrPort(endpoint),
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
|
|||||||
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (int, error) {
|
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (int, error) {
|
||||||
udpep, ok := ep.(*EndPoint)
|
udpep, ok := ep.(*EndPoint)
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, ErrEndpointTypeMistatch
|
return 0, p2p.ErrEndpointTypeMistatch
|
||||||
}
|
}
|
||||||
return (*net.UDPConn)(conn).WriteTo(b, (*net.UDPAddr)(udpep))
|
return (*net.UDPConn)(conn).WriteTo(b, (*net.UDPAddr)(udpep))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
_ "github.com/fumiama/WireGold/gold/p2p/tcp" // support tcp connection
|
||||||
_ "github.com/fumiama/WireGold/gold/p2p/udp" // support udp connection
|
_ "github.com/fumiama/WireGold/gold/p2p/udp" // support udp connection
|
||||||
|
|
||||||
"github.com/fumiama/WireGold/gold/head"
|
"github.com/fumiama/WireGold/gold/head"
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
"github.com/fumiama/WireGold/helper"
|
"github.com/fumiama/WireGold/helper"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testTunnel(t *testing.T, isplain bool, pshk *[32]byte) {
|
func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) {
|
||||||
selfpk, err := curve.New(nil)
|
selfpk, err := curve.New(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@@ -33,6 +33,7 @@ func testTunnel(t *testing.T, isplain bool, pshk *[32]byte) {
|
|||||||
m := link.NewMe(&link.MyConfig{
|
m := link.NewMe(&link.MyConfig{
|
||||||
MyIPwithMask: "192.168.1.2/32",
|
MyIPwithMask: "192.168.1.2/32",
|
||||||
MyEndpoint: "127.0.0.1:0",
|
MyEndpoint: "127.0.0.1:0",
|
||||||
|
Network: nw,
|
||||||
PrivateKey: selfpk.Private(),
|
PrivateKey: selfpk.Private(),
|
||||||
SrcPort: 1,
|
SrcPort: 1,
|
||||||
DstPort: 1,
|
DstPort: 1,
|
||||||
@@ -43,6 +44,7 @@ func testTunnel(t *testing.T, isplain bool, pshk *[32]byte) {
|
|||||||
p := link.NewMe(&link.MyConfig{
|
p := link.NewMe(&link.MyConfig{
|
||||||
MyIPwithMask: "192.168.1.3/32",
|
MyIPwithMask: "192.168.1.3/32",
|
||||||
MyEndpoint: "127.0.0.1:0",
|
MyEndpoint: "127.0.0.1:0",
|
||||||
|
Network: nw,
|
||||||
PrivateKey: peerpk.Private(),
|
PrivateKey: peerpk.Private(),
|
||||||
SrcPort: 1,
|
SrcPort: 1,
|
||||||
DstPort: 1,
|
DstPort: 1,
|
||||||
@@ -146,20 +148,36 @@ func testTunnel(t *testing.T, isplain bool, pshk *[32]byte) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTunnel(t *testing.T) {
|
func TestTunnelUDP(t *testing.T) {
|
||||||
logrus.SetLevel(logrus.DebugLevel)
|
logrus.SetLevel(logrus.DebugLevel)
|
||||||
logrus.SetFormatter(&logFormat{enableColor: false})
|
logrus.SetFormatter(&logFormat{enableColor: false})
|
||||||
|
|
||||||
testTunnel(t, true, nil) // test plain text
|
testTunnel(t, "udp", true, nil) // test plain text
|
||||||
|
|
||||||
testTunnel(t, false, nil) // test normal
|
testTunnel(t, "udp", false, nil) // test normal
|
||||||
|
|
||||||
var buf [32]byte
|
var buf [32]byte
|
||||||
_, err := rand.Read(buf[:])
|
_, err := rand.Read(buf[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
testTunnel(t, false, &buf) // test preshared
|
testTunnel(t, "udp", false, &buf) // test preshared
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTunnelTCP(t *testing.T) {
|
||||||
|
logrus.SetLevel(logrus.DebugLevel)
|
||||||
|
logrus.SetFormatter(&logFormat{enableColor: false})
|
||||||
|
|
||||||
|
testTunnel(t, "tcp", true, nil) // test plain text
|
||||||
|
|
||||||
|
testTunnel(t, "tcp", false, nil) // test normal
|
||||||
|
|
||||||
|
var buf [32]byte
|
||||||
|
_, err := rand.Read(buf[:])
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
testTunnel(t, "tcp", false, &buf) // test preshared
|
||||||
}
|
}
|
||||||
|
|
||||||
// logFormat specialize for go-cqhttp
|
// logFormat specialize for go-cqhttp
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
curve "github.com/fumiama/go-x25519"
|
curve "github.com/fumiama/go-x25519"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
_ "github.com/fumiama/WireGold/gold/p2p/tcp" // support tcp connection
|
||||||
_ "github.com/fumiama/WireGold/gold/p2p/udp" // support udp connection
|
_ "github.com/fumiama/WireGold/gold/p2p/udp" // support udp connection
|
||||||
|
|
||||||
"github.com/fumiama/WireGold/config"
|
"github.com/fumiama/WireGold/config"
|
||||||
@@ -36,7 +37,7 @@ func NewWireGold(c *config.Config) (wg WG, err error) {
|
|||||||
}
|
}
|
||||||
n := copy(wg.key[:], base14.Decode(k))
|
n := copy(wg.key[:], base14.Decode(k))
|
||||||
if n != 32 {
|
if n != 32 {
|
||||||
err = errors.New("private key length is not 32")
|
err = errors.New("private key length != 32, got " + strconv.Itoa(n))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user