From 9f9eb1d83f6f2fb83d632cc731fedfd33ba90283 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 15 Apr 2022 15:51:39 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B6=E5=8C=85=E5=AE=9A=E6=97=B6=E5=99=A8?= =?UTF-8?q?=E6=94=B9=E7=94=A8ttl?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gold/link/crypto.go | 16 ++++++++-------- gold/link/me.go | 6 +----- gold/link/recv.go | 41 ++++++----------------------------------- gold/link/router.go | 16 ++++++++-------- 4 files changed, 23 insertions(+), 56 deletions(-) diff --git a/gold/link/crypto.go b/gold/link/crypto.go index a5317a3..9e0805e 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -7,11 +7,11 @@ func (l *Link) Encode(teatype uint8, b []byte) (eb []byte) { } if l.key == nil { eb = b - } else { - // 在此处填写加密逻辑,密钥是l.key,输入是b,输出是eb - // 不用写return,直接赋值给eb即可 - eb = l.key[teatype].Encrypt(b) + return } + // 在此处填写加密逻辑,密钥是l.key,输入是b,输出是eb + // 不用写return,直接赋值给eb即可 + eb = l.key[teatype].Encrypt(b) return } @@ -22,10 +22,10 @@ func (l *Link) Decode(teatype uint8, b []byte) (db []byte) { } if l.key == nil { db = b - } else { - // 在此处填写解密逻辑,密钥是l.key,输入是b,输出是db - // 不用写return,直接赋值给db即可 - db = l.key[teatype].Decrypt(b) + return } + // 在此处填写解密逻辑,密钥是l.key,输入是b,输出是db + // 不用写return,直接赋值给db即可 + db = l.key[teatype].Decrypt(b) return } diff --git a/gold/link/me.go b/gold/link/me.go index 40a4069..62c4922 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -44,11 +44,7 @@ type Me struct { // 本机发送缓冲区 writer *helper.Writer // 本机未接收完全分片池 - recving map[[32]byte]*head.Packet - // 接收锁 - recvmu sync.Mutex - // 收包超时定时器 - clock map[*head.Packet]uint8 + recving *ttl.Cache[[32]byte, *head.Packet] // 本机上层配置 srcport, dstport, mtu uint16 } diff --git a/gold/link/recv.go b/gold/link/recv.go index 297703a..1bd009b 100644 --- a/gold/link/recv.go +++ b/gold/link/recv.go @@ -9,6 +9,7 @@ import ( "github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/helper" "github.com/sirupsen/logrus" + "github.com/wdvxdr1123/ZeroBot/extension/ttl" ) // Read 从 peer 收包 @@ -20,32 +21,7 @@ func (m *Me) initrecvpool() { if m.writer == nil { m.writer = helper.SelectWriter() } - if m.recving == nil { - m.recving = make(map[[32]byte]*head.Packet, 128) - } - // 超时定时器 - m.clock = make(map[*head.Packet]uint8, 128) - var delhs []*head.Packet - t := time.NewTicker(time.Second) - go func() { - for range t.C { - m.recvmu.Lock() - for k, v := range m.clock { - if v > 10 { // 10s - delete(m.recving, k.Hash) - delhs = append(delhs, k) - } else { - m.clock[k]++ - } - } - for _, k := range delhs { - delete(m.clock, k) - logrus.Warnln("[recv] drop timeout packet from", k.Src) - } - delhs = delhs[:0] - m.recvmu.Unlock() - } - }() + m.recving = ttl.NewCache[[32]byte, *head.Packet](time.Second * 128) } func (m *Me) wait(data []byte) *head.Packet { @@ -61,22 +37,18 @@ func (m *Me) wait(data []byte) *head.Packet { return h } - m.recvmu.Lock() - defer m.recvmu.Unlock() hashd := data[20:52] hsh := *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&hashd))) - h, ok := m.recving[hsh] - if ok { + h := m.recving.Get(hsh) + if h != nil { logrus.Debugln("[recv] get another frag part of", hex.EncodeToString(hashd)) ok, err := h.Unmarshal(data) if err == nil { if ok { - delete(m.clock, h) - delete(m.recving, hsh) + m.recving.Delete(hsh) logrus.Debugln("[recv] all parts of", hex.EncodeToString(hashd), "is reached") return h } - m.clock[h] = 0 } else { logrus.Errorln("[recv] unmarshal err:", err) } @@ -89,7 +61,6 @@ func (m *Me) wait(data []byte) *head.Packet { logrus.Errorln("[recv] unmarshal err:", err) return nil } - m.recving[hsh] = h - m.clock[h] = 0 + m.recving.Set(hsh, h) return nil } diff --git a/gold/link/router.go b/gold/link/router.go index 8221a98..85670a4 100644 --- a/gold/link/router.go +++ b/gold/link/router.go @@ -10,9 +10,9 @@ import ( ) type Router struct { + sync.RWMutex // map[cidr]*Link table map[string]*Link - mu sync.RWMutex list []*net.IPNet cache *ttl.Cache[string, *Link] } @@ -35,10 +35,10 @@ func (l *Link) IsToMe(ip net.IP) bool { // SetDefault 设置默认网关 func (r *Router) SetDefault(l *Link) { defnet := &net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)} - r.mu.Lock() + r.Lock() r.list[len(r.list)-1] = defnet r.table[defnet.String()] = l - r.mu.Unlock() + r.Unlock() } // NextHop 得到前往 ip 的下一跳的 link @@ -54,10 +54,10 @@ func (r *Router) NextHop(ip string) (l *Link) { return } - // TODO: 遍历 r.table,得到正确的下一跳 + // 遍历 r.table,得到正确的下一跳 // 注意使用 r.mu 读写锁避免竞争 - r.mu.RLock() - defer r.mu.RUnlock() + r.RLock() + defer r.RUnlock() for _, c := range r.list { if c.Contains(ipb) { @@ -75,7 +75,7 @@ func (r *Router) NextHop(ip string) (l *Link) { // SetItem 添加一条表项 func (r *Router) SetItem(ip *net.IPNet, l *Link) { - r.mu.Lock() + r.Lock() // 从第一条表项开始匹配 for i := 0; i < len(r.list); i++ { if r.list[i].Contains(ip.IP) { @@ -94,7 +94,7 @@ func (r *Router) SetItem(ip *net.IPNet, l *Link) { break } } - r.mu.Unlock() + r.Unlock() } func isSubnetBcast(ip net.IP, subnet *net.IPNet) bool {