1
0
mirror of https://github.com/fumiama/dnskip.git synced 2026-06-30 01:30:24 +08:00

feat: add fallback logic

This commit is contained in:
源文雨
2025-10-03 22:08:13 +08:00
parent f7f2ef3663
commit 473f442e0d

88
main.go
View File

@@ -3,6 +3,7 @@ package main
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"flag" "flag"
"io" "io"
"math" "math"
@@ -28,10 +29,14 @@ var (
_ = c.Close() _ = c.Close()
}, nil, }, nil,
}) })
fallback *net.UDPAddr
forcefb bool
) )
func main() { func main() {
iphost := flag.String("l", "127.0.0.1:5345", "listen DNS UDP port") iphost := flag.String("l", "0.0.0.0:53", "listen DNS UDP port")
fbsrv := flag.String("fb", "127.0.0.1:5345", "fallback to DNS UDP port")
flag.BoolVar(&forcefb, "ffb", false, "force using fallback")
frag := flag.Uint("frag", 3, "TLS first fragemt size") frag := flag.Uint("frag", 3, "TLS first fragemt size")
flag.Parse() flag.Parse()
@@ -39,6 +44,15 @@ func main() {
terasu.DefaultFirstFragmentLen = uint8(*frag) terasu.DefaultFirstFragmentLen = uint8(*frag)
} }
if *fbsrv != "" {
addrport, err := netip.ParseAddrPort(*fbsrv)
if err != nil {
logrus.Fatal("ParseAddrPort err:", err)
}
fallback = net.UDPAddrFromAddrPort(addrport)
logrus.Infoln("Set fallback server to", fallback)
}
addrport, err := netip.ParseAddrPort(*iphost) addrport, err := netip.ParseAddrPort(*iphost)
if err != nil { if err != nil {
logrus.Fatal("ParseAddrPort err:", err) logrus.Fatal("ParseAddrPort err:", err)
@@ -74,12 +88,21 @@ RECONN:
} }
func response(cnt uint8, conn *net.UDPConn, addr *net.UDPAddr, payload pbuf.Bytes) { func response(cnt uint8, conn *net.UDPConn, addr *net.UDPAddr, payload pbuf.Bytes) {
var (
err error
tlsconn net.Conn
loopcnt = 0
)
if forcefb {
goto FALLBACK
}
defer releasefree(cnt) defer releasefree(cnt)
logrus.Debugln(addr, "Run on lock", cnt) logrus.Debugln(addr, "Run on lock", cnt)
loopcnt := 0
REDAIL: REDAIL:
tlsconn, err := dialtls(cnt) tlsconn, err = dialtls(cnt)
if err != nil { if err != nil {
logrus.Warnln(addr, "Dial DNS server err:", err) logrus.Warnln(addr, "Dial DNS server err:", err)
return return
@@ -118,15 +141,60 @@ REDAIL:
logrus.Debugln("Write response to", addr) logrus.Debugln("Write response to", addr)
}) })
}) })
if err != nil { if err == nil {
logrus.Warnln(addr, "Proxy to DNS server err:", err)
tlsconnCache.Delete(cnt)
loopcnt++
if loopcnt < 3 {
goto REDAIL
}
return return
} }
logrus.Warnln(addr, "Proxy to DNS server err:", err)
tlsconnCache.Delete(cnt)
loopcnt++
if loopcnt < 2 {
goto REDAIL
}
FALLBACK:
fbconn, err := net.DialUDP("udp", nil, fallback) // dummy fill
if err != nil {
logrus.Warnln(addr, "Fallback DialUDP err:", err)
return
}
logrus.Warnln(addr, "Fallback from", fbconn.LocalAddr(), "to", fbconn.RemoteAddr())
defer fbconn.Close()
payload.V(func(b []byte) {
_, err = fbconn.Write(b)
if err != nil {
logrus.Warnln(addr, "Write to fallback err:", err)
return
}
})
if err != nil {
return
}
_ = fbconn.SetReadDeadline(time.Now().Add(time.Second * 4))
pbuf.NewBytes(4096).V(func(b []byte) {
var (
n int
srvaddr *net.UDPAddr
)
for {
n, srvaddr, err = fbconn.ReadFromUDP(b)
if err != nil {
logrus.Warnln(addr, "Read from fallback err:", err)
return
}
if !srvaddr.IP.Equal(fallback.IP) || srvaddr.Port != fallback.Port {
logrus.Warnln(addr, "Read expect fallback addr but got", srvaddr)
err = errors.New("wrong endpoint")
continue
}
break
}
if err == nil {
_, err = conn.WriteToUDP(b[:n], addr)
if err != nil {
return
}
logrus.Debugln("Write fallback response to", addr)
}
})
} }
// lockfree is spin update // lockfree is spin update