From fffef1be40e7a909daed740bb2da86edb5c550f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sun, 5 Oct 2025 12:40:27 +0800 Subject: [PATCH] feat(dns): support lookup in sequence --- dns/dns.go | 50 +++++++++++++++++++++++++++++++++++++++---------- dns/dns_test.go | 9 +++++++-- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/dns/dns.go b/dns/dns.go index 8257869..a078cd8 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -41,8 +41,9 @@ func (ds *dnsstat) disable(reEnable time.Duration) { type DNSList struct { sync.RWMutex - m map[string][]*dnsstat - b map[string][]string + hostseq []string + m map[string][]*dnsstat + b map[string][]string } type DNSConfig struct { @@ -74,9 +75,14 @@ func (ds *DNSList) Add(c *DNSConfig) { ds.Lock() defer ds.Unlock() addList := map[string][]*dnsstat{} + addHosts := map[string]struct{}{} for host, addrs := range c.Servers { + availableHosts, ok := ds.m[host] + if !ok { + addHosts[host] = struct{}{} + } for _, addr := range addrs { - if !hasrecord(ds.m[host], addr) && !hasrecord(addList[host], addr) { + if !hasrecord(availableHosts, addr) && !hasrecord(addList[host], addr) { addList[host] = append(addList[host], &dnsstat{addr, true}) } } @@ -84,6 +90,9 @@ func (ds *DNSList) Add(c *DNSConfig) { for host, addrs := range addList { ds.m[host] = append(ds.m[host], addrs...) } + for host := range addHosts { + ds.hostseq = append(ds.hostseq, host) + } addListFallback := map[string][]string{} for host, addrs := range c.Fallbacks { for _, addr := range addrs { @@ -97,24 +106,38 @@ func (ds *DNSList) Add(c *DNSConfig) { } } -func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) ([]string, error) { +// rangeHosts in sequence, please use in rlock +func (ds *DNSList) rangeHosts(fn func(host string, addrs []*dnsstat) error) error { + for _, h := range ds.hostseq { + if err := fn(h, ds.m[h]); err != nil { + return err + } + } + return nil +} + +func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) (hosts []string, err error) { ds.RLock() defer ds.RUnlock() // try to use DoH first - for _, addrs := range ds.m { + err = ds.rangeHosts(func(_ string, addrs []*dnsstat) error { for _, addr := range addrs { if !addr.e || !strings.HasPrefix(addr.a, "https://") { // disabled or is not DoH continue } jr, err := lookupdoh(ctx, addr.a, host) if err == nil { - hosts := jr.hosts() + hosts = jr.hosts() if len(hosts) > 0 { - return hosts, nil + return nil } } addr.disable(time.Hour) // no need to acquire write lock } + return nil // not found, fallback to ds.b + }) + if len(hosts) > 0 || err != nil { + return } if addrs, ok := ds.b[host]; ok { return addrs, nil @@ -133,7 +156,7 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFra defer ds.RUnlock() var conn net.Conn - for host, addrs := range ds.m { + _ = ds.rangeHosts(func(host string, addrs []*dnsstat) error { for _, addr := range addrs { if !addr.e || strings.HasPrefix(addr.a, "https://") { // disabled or is DoH continue @@ -155,16 +178,20 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFra tlsConn = tls.Client(conn, &tls.Config{ServerName: host}) err = terasu.Use(tlsConn).HandshakeContext(ctx, firstFragmentLen) if err == nil { - return + return nil } _ = tlsConn.Close() addr.disable(time.Hour) // no need to acquire write lock } - } + return nil + }) return } var IPv6Servers = DNSList{ + hostseq: []string{ + "dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net", + }, m: map[string][]*dnsstat{ "dot.sb": { {"[2a09::]:853", true}, @@ -198,6 +225,9 @@ var IPv6Servers = DNSList{ } var IPv4Servers = DNSList{ + hostseq: []string{ + "dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net", + }, m: map[string][]*dnsstat{ "dot.sb": { {"185.222.222.222:853", true}, diff --git a/dns/dns_test.go b/dns/dns_test.go index 62b8035..0439e72 100644 --- a/dns/dns_test.go +++ b/dns/dns_test.go @@ -66,9 +66,13 @@ func TestDNS(t *testing.T) { } func TestBadDNS(t *testing.T) { + dotv6serversseqbak := IPv6Servers.hostseq + dotv4serversseqbak := IPv4Servers.hostseq dotv6serversbak := IPv6Servers.m dotv4serversbak := IPv4Servers.m defer func() { + IPv6Servers.hostseq = dotv6serversseqbak + IPv4Servers.hostseq = dotv4serversseqbak IPv6Servers.m = dotv6serversbak IPv4Servers.m = dotv4serversbak }() @@ -100,7 +104,7 @@ func TestBadDNS(t *testing.T) { func (ds *DNSList) test() { ds.RLock() defer ds.RUnlock() - for host, addrs := range ds.m { + _ = ds.rangeHosts(func(host string, addrs []*dnsstat) error { for _, addr := range addrs { if !addr.e { continue @@ -119,5 +123,6 @@ func (ds *DNSList) test() { } fmt.Println("fail:", host, addr.a) } - } + return nil + }) }