1
0
mirror of https://github.com/fumiama/dnskip.git synced 2026-06-28 08:40:23 +08:00

fix: too many conn in parallel

This commit is contained in:
源文雨
2026-01-20 22:52:24 +08:00
parent d25b80df48
commit 7d1257379b
3 changed files with 71 additions and 50 deletions

83
main.go
View File

@@ -2,6 +2,7 @@ package main
import (
"context"
"crypto/tls"
"encoding/binary"
"errors"
"flag"
@@ -9,11 +10,12 @@ import (
"math"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/FloatTech/ttl"
"github.com/fumiama/orbyte/pbuf"
"github.com/fumiama/terasu"
"github.com/fumiama/terasu/dns"
@@ -22,14 +24,6 @@ import (
)
var (
freeconn = uintptr(0)
tlsconnCache = ttl.NewCacheOn(
5*time.Minute, [4]func(uint8, net.Conn){
nil, nil, func(_ uint8, c net.Conn) {
logrus.Warnln("Close idle/error tls conn to", c.RemoteAddr())
_ = c.Close()
}, nil,
})
fallback *net.UDPAddr
forcefb bool
timeout uint
@@ -103,7 +97,8 @@ RECONN:
func response(cnt uint8, conn *net.UDPConn, addr *net.UDPAddr, payload pbuf.Bytes) {
var (
err error
tlsconn net.Conn
cl func()
tlsconn *tls.Conn
loopcnt = 0
ctx context.Context
cancel context.CancelFunc
@@ -114,12 +109,12 @@ func response(cnt uint8, conn *net.UDPConn, addr *net.UDPAddr, payload pbuf.Byte
}
defer releasefree(cnt)
logrus.Debugln(addr, "Run on lock", cnt)
logrus.Debugln(addr, "Run thread", cnt)
REDAIL:
ctx, cancel = context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
defer cancel()
tlsconn, err = dialtls(cnt, ctx)
tlsconn, cl, err = dialtls(cnt, ctx)
if err != nil {
logrus.Warnln(addr, "Dial DNS server err:", err)
return
@@ -127,6 +122,8 @@ REDAIL:
logrus.Debugln(addr, "Dial to DNS server", tlsconn.RemoteAddr())
payload.V(func(b []byte) {
defer cl()
var plen [2]byte
binary.BigEndian.PutUint16(plen[:], uint16(len(b)))
_, err = io.Copy(tlsconn, &net.Buffers{plen[:], b})
@@ -162,7 +159,10 @@ REDAIL:
return
}
logrus.Warnln(addr, "Proxy to DNS server err:", err)
tlsconnCache.Delete(cnt)
atomic.StorePointer(
(*unsafe.Pointer)(unsafe.Pointer(&remoconn)),
unsafe.Pointer(nil),
)
loopcnt++
if loopcnt < 2 {
goto REDAIL
@@ -214,17 +214,21 @@ FALLBACK:
})
}
var (
freeconn = uint32(0)
)
// lockfree is spin update
func lockfree() uint8 {
old := atomic.LoadUintptr(&freeconn)
old := atomic.LoadUint32(&freeconn)
for i := uint8(0); i < uint8(unsafe.Sizeof(uintptr(0)))*8; i++ {
for old&(1<<i) == 0 { // is free
ok := atomic.CompareAndSwapUintptr(&freeconn, old, old|(1<<i))
ok := atomic.CompareAndSwapUint32(&freeconn, old, old|(1<<i))
if ok {
return i
}
// update latest
old = atomic.LoadUintptr(&freeconn)
old = atomic.LoadUint32(&freeconn)
}
}
return math.MaxUint8
@@ -232,29 +236,52 @@ func lockfree() uint8 {
// releasefree is spin update
func releasefree(i uint8) {
old := atomic.LoadUintptr(&freeconn)
old := atomic.LoadUint32(&freeconn)
for old&(1<<i) != 0 { // is not free
ok := atomic.CompareAndSwapUintptr(&freeconn, old, old&^(1<<i))
ok := atomic.CompareAndSwapUint32(&freeconn, old, old&^(1<<i))
if ok {
return
}
// update latest
old = atomic.LoadUintptr(&freeconn)
old = atomic.LoadUint32(&freeconn)
}
logrus.Debugln("Free thread", i)
}
func dialtls(cnt uint8, ctx context.Context) (net.Conn, error) {
conn := tlsconnCache.Get(cnt)
var (
remoconn *tls.Conn
connmu sync.Mutex
)
func dialtls(cnt uint8, ctx context.Context) (*tls.Conn, func(), error) {
conn := (*tls.Conn)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&remoconn))))
if conn != nil {
logrus.Debugln("Lock", cnt, "get cached tls conn to", conn.RemoteAddr())
return conn, nil
logrus.Debugln("Thread", cnt, "get cached tls conn to", conn.RemoteAddr())
connmu.Lock()
return conn, connmu.Unlock, nil
}
// slow path
connmu.Lock()
conn = (*tls.Conn)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&remoconn))))
if conn != nil {
logrus.Debugln("Thread", cnt, "slowly get cached tls conn to", conn.RemoteAddr())
return conn, connmu.Unlock, nil
}
// dummy nw and addr
conn, err := dns.DefaultResolver.Dial(ctx, "", "")
connintf, err := dns.DefaultResolver.Dial(ctx, "", "")
if err != nil {
return nil, err
connmu.Unlock()
return nil, nil, err
}
tlsconnCache.Set(cnt, conn)
logrus.Debugln("Lock", cnt, "set new tls conn to", conn.RemoteAddr())
return conn, nil
conn = connintf.(*tls.Conn)
atomic.StorePointer(
(*unsafe.Pointer)(unsafe.Pointer(&remoconn)),
unsafe.Pointer(conn),
)
runtime.SetFinalizer(conn, func(conn *tls.Conn) {
logrus.Warnln("Cleanup unused conn to", conn.RemoteAddr())
_ = conn.Close()
})
logrus.Debugln("Thread", cnt, "set new tls conn to", conn.RemoteAddr())
return conn, connmu.Unlock, nil
}