diff --git a/main.go b/main.go index aa8190a..7e79603 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/binary" + "errors" "flag" "io" "math" @@ -28,10 +29,14 @@ var ( _ = c.Close() }, nil, }) + fallback *net.UDPAddr + forcefb bool ) func main() { - iphost := flag.String("l", "127.0.0.1:5345", "listen DNS UDP port") + iphost := flag.String("l", "0.0.0.0:53", "listen DNS UDP port") + fbsrv := flag.String("fb", "127.0.0.1:5345", "fallback to DNS UDP port") + flag.BoolVar(&forcefb, "ffb", false, "force using fallback") frag := flag.Uint("frag", 3, "TLS first fragemt size") flag.Parse() @@ -39,6 +44,15 @@ func main() { terasu.DefaultFirstFragmentLen = uint8(*frag) } + if *fbsrv != "" { + addrport, err := netip.ParseAddrPort(*fbsrv) + if err != nil { + logrus.Fatal("ParseAddrPort err:", err) + } + fallback = net.UDPAddrFromAddrPort(addrport) + logrus.Infoln("Set fallback server to", fallback) + } + addrport, err := netip.ParseAddrPort(*iphost) if err != nil { logrus.Fatal("ParseAddrPort err:", err) @@ -74,12 +88,21 @@ RECONN: } func response(cnt uint8, conn *net.UDPConn, addr *net.UDPAddr, payload pbuf.Bytes) { + var ( + err error + tlsconn net.Conn + loopcnt = 0 + ) + + if forcefb { + goto FALLBACK + } + defer releasefree(cnt) logrus.Debugln(addr, "Run on lock", cnt) - loopcnt := 0 REDAIL: - tlsconn, err := dialtls(cnt) + tlsconn, err = dialtls(cnt) if err != nil { logrus.Warnln(addr, "Dial DNS server err:", err) return @@ -118,15 +141,60 @@ REDAIL: logrus.Debugln("Write response to", addr) }) }) - if err != nil { - logrus.Warnln(addr, "Proxy to DNS server err:", err) - tlsconnCache.Delete(cnt) - loopcnt++ - if loopcnt < 3 { - goto REDAIL - } + if err == nil { return } + logrus.Warnln(addr, "Proxy to DNS server err:", err) + tlsconnCache.Delete(cnt) + loopcnt++ + if loopcnt < 2 { + goto REDAIL + } +FALLBACK: + fbconn, err := net.DialUDP("udp", nil, fallback) // dummy fill + if err != nil { + logrus.Warnln(addr, "Fallback DialUDP err:", err) + return + } + logrus.Warnln(addr, "Fallback from", fbconn.LocalAddr(), "to", fbconn.RemoteAddr()) + defer fbconn.Close() + payload.V(func(b []byte) { + _, err = fbconn.Write(b) + if err != nil { + logrus.Warnln(addr, "Write to fallback err:", err) + return + } + }) + if err != nil { + return + } + _ = fbconn.SetReadDeadline(time.Now().Add(time.Second * 4)) + pbuf.NewBytes(4096).V(func(b []byte) { + var ( + n int + srvaddr *net.UDPAddr + ) + for { + n, srvaddr, err = fbconn.ReadFromUDP(b) + if err != nil { + logrus.Warnln(addr, "Read from fallback err:", err) + return + } + if !srvaddr.IP.Equal(fallback.IP) || srvaddr.Port != fallback.Port { + logrus.Warnln(addr, "Read expect fallback addr but got", srvaddr) + err = errors.New("wrong endpoint") + continue + } + break + } + if err == nil { + _, err = conn.WriteToUDP(b[:n], addr) + if err != nil { + return + } + logrus.Debugln("Write fallback response to", addr) + } + }) } // lockfree is spin update