mirror of
https://github.com/fumiama/dnskip.git
synced 2026-06-28 08:40:23 +08:00
fix: too many conn in parallel
This commit is contained in:
83
main.go
83
main.go
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"flag"
|
||||
@@ -9,11 +10,12 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/FloatTech/ttl"
|
||||
"github.com/fumiama/orbyte/pbuf"
|
||||
"github.com/fumiama/terasu"
|
||||
"github.com/fumiama/terasu/dns"
|
||||
@@ -22,14 +24,6 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
freeconn = uintptr(0)
|
||||
tlsconnCache = ttl.NewCacheOn(
|
||||
5*time.Minute, [4]func(uint8, net.Conn){
|
||||
nil, nil, func(_ uint8, c net.Conn) {
|
||||
logrus.Warnln("Close idle/error tls conn to", c.RemoteAddr())
|
||||
_ = c.Close()
|
||||
}, nil,
|
||||
})
|
||||
fallback *net.UDPAddr
|
||||
forcefb bool
|
||||
timeout uint
|
||||
@@ -103,7 +97,8 @@ RECONN:
|
||||
func response(cnt uint8, conn *net.UDPConn, addr *net.UDPAddr, payload pbuf.Bytes) {
|
||||
var (
|
||||
err error
|
||||
tlsconn net.Conn
|
||||
cl func()
|
||||
tlsconn *tls.Conn
|
||||
loopcnt = 0
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -114,12 +109,12 @@ func response(cnt uint8, conn *net.UDPConn, addr *net.UDPAddr, payload pbuf.Byte
|
||||
}
|
||||
|
||||
defer releasefree(cnt)
|
||||
logrus.Debugln(addr, "Run on lock", cnt)
|
||||
logrus.Debugln(addr, "Run thread", cnt)
|
||||
|
||||
REDAIL:
|
||||
ctx, cancel = context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
|
||||
defer cancel()
|
||||
tlsconn, err = dialtls(cnt, ctx)
|
||||
tlsconn, cl, err = dialtls(cnt, ctx)
|
||||
if err != nil {
|
||||
logrus.Warnln(addr, "Dial DNS server err:", err)
|
||||
return
|
||||
@@ -127,6 +122,8 @@ REDAIL:
|
||||
logrus.Debugln(addr, "Dial to DNS server", tlsconn.RemoteAddr())
|
||||
|
||||
payload.V(func(b []byte) {
|
||||
defer cl()
|
||||
|
||||
var plen [2]byte
|
||||
binary.BigEndian.PutUint16(plen[:], uint16(len(b)))
|
||||
_, err = io.Copy(tlsconn, &net.Buffers{plen[:], b})
|
||||
@@ -162,7 +159,10 @@ REDAIL:
|
||||
return
|
||||
}
|
||||
logrus.Warnln(addr, "Proxy to DNS server err:", err)
|
||||
tlsconnCache.Delete(cnt)
|
||||
atomic.StorePointer(
|
||||
(*unsafe.Pointer)(unsafe.Pointer(&remoconn)),
|
||||
unsafe.Pointer(nil),
|
||||
)
|
||||
loopcnt++
|
||||
if loopcnt < 2 {
|
||||
goto REDAIL
|
||||
@@ -214,17 +214,21 @@ FALLBACK:
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
freeconn = uint32(0)
|
||||
)
|
||||
|
||||
// lockfree is spin update
|
||||
func lockfree() uint8 {
|
||||
old := atomic.LoadUintptr(&freeconn)
|
||||
old := atomic.LoadUint32(&freeconn)
|
||||
for i := uint8(0); i < uint8(unsafe.Sizeof(uintptr(0)))*8; i++ {
|
||||
for old&(1<<i) == 0 { // is free
|
||||
ok := atomic.CompareAndSwapUintptr(&freeconn, old, old|(1<<i))
|
||||
ok := atomic.CompareAndSwapUint32(&freeconn, old, old|(1<<i))
|
||||
if ok {
|
||||
return i
|
||||
}
|
||||
// update latest
|
||||
old = atomic.LoadUintptr(&freeconn)
|
||||
old = atomic.LoadUint32(&freeconn)
|
||||
}
|
||||
}
|
||||
return math.MaxUint8
|
||||
@@ -232,29 +236,52 @@ func lockfree() uint8 {
|
||||
|
||||
// releasefree is spin update
|
||||
func releasefree(i uint8) {
|
||||
old := atomic.LoadUintptr(&freeconn)
|
||||
old := atomic.LoadUint32(&freeconn)
|
||||
for old&(1<<i) != 0 { // is not free
|
||||
ok := atomic.CompareAndSwapUintptr(&freeconn, old, old&^(1<<i))
|
||||
ok := atomic.CompareAndSwapUint32(&freeconn, old, old&^(1<<i))
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
// update latest
|
||||
old = atomic.LoadUintptr(&freeconn)
|
||||
old = atomic.LoadUint32(&freeconn)
|
||||
}
|
||||
logrus.Debugln("Free thread", i)
|
||||
}
|
||||
|
||||
func dialtls(cnt uint8, ctx context.Context) (net.Conn, error) {
|
||||
conn := tlsconnCache.Get(cnt)
|
||||
var (
|
||||
remoconn *tls.Conn
|
||||
connmu sync.Mutex
|
||||
)
|
||||
|
||||
func dialtls(cnt uint8, ctx context.Context) (*tls.Conn, func(), error) {
|
||||
conn := (*tls.Conn)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&remoconn))))
|
||||
if conn != nil {
|
||||
logrus.Debugln("Lock", cnt, "get cached tls conn to", conn.RemoteAddr())
|
||||
return conn, nil
|
||||
logrus.Debugln("Thread", cnt, "get cached tls conn to", conn.RemoteAddr())
|
||||
connmu.Lock()
|
||||
return conn, connmu.Unlock, nil
|
||||
}
|
||||
// slow path
|
||||
connmu.Lock()
|
||||
conn = (*tls.Conn)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&remoconn))))
|
||||
if conn != nil {
|
||||
logrus.Debugln("Thread", cnt, "slowly get cached tls conn to", conn.RemoteAddr())
|
||||
return conn, connmu.Unlock, nil
|
||||
}
|
||||
// dummy nw and addr
|
||||
conn, err := dns.DefaultResolver.Dial(ctx, "", "")
|
||||
connintf, err := dns.DefaultResolver.Dial(ctx, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
connmu.Unlock()
|
||||
return nil, nil, err
|
||||
}
|
||||
tlsconnCache.Set(cnt, conn)
|
||||
logrus.Debugln("Lock", cnt, "set new tls conn to", conn.RemoteAddr())
|
||||
return conn, nil
|
||||
conn = connintf.(*tls.Conn)
|
||||
atomic.StorePointer(
|
||||
(*unsafe.Pointer)(unsafe.Pointer(&remoconn)),
|
||||
unsafe.Pointer(conn),
|
||||
)
|
||||
runtime.SetFinalizer(conn, func(conn *tls.Conn) {
|
||||
logrus.Warnln("Cleanup unused conn to", conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
})
|
||||
logrus.Debugln("Thread", cnt, "set new tls conn to", conn.RemoteAddr())
|
||||
return conn, connmu.Unlock, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user