1
0
mirror of https://github.com/fumiama/terasu.git synced 2026-06-05 09:10:24 +08:00

feat(dns): support lookup in sequence

This commit is contained in:
源文雨
2025-10-05 12:40:27 +08:00
parent f6efbe4b61
commit fffef1be40
2 changed files with 47 additions and 12 deletions

View File

@@ -41,8 +41,9 @@ func (ds *dnsstat) disable(reEnable time.Duration) {
type DNSList struct {
sync.RWMutex
m map[string][]*dnsstat
b map[string][]string
hostseq []string
m map[string][]*dnsstat
b map[string][]string
}
type DNSConfig struct {
@@ -74,9 +75,14 @@ func (ds *DNSList) Add(c *DNSConfig) {
ds.Lock()
defer ds.Unlock()
addList := map[string][]*dnsstat{}
addHosts := map[string]struct{}{}
for host, addrs := range c.Servers {
availableHosts, ok := ds.m[host]
if !ok {
addHosts[host] = struct{}{}
}
for _, addr := range addrs {
if !hasrecord(ds.m[host], addr) && !hasrecord(addList[host], addr) {
if !hasrecord(availableHosts, addr) && !hasrecord(addList[host], addr) {
addList[host] = append(addList[host], &dnsstat{addr, true})
}
}
@@ -84,6 +90,9 @@ func (ds *DNSList) Add(c *DNSConfig) {
for host, addrs := range addList {
ds.m[host] = append(ds.m[host], addrs...)
}
for host := range addHosts {
ds.hostseq = append(ds.hostseq, host)
}
addListFallback := map[string][]string{}
for host, addrs := range c.Fallbacks {
for _, addr := range addrs {
@@ -97,24 +106,38 @@ func (ds *DNSList) Add(c *DNSConfig) {
}
}
func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) ([]string, error) {
// rangeHosts in sequence, please use in rlock
func (ds *DNSList) rangeHosts(fn func(host string, addrs []*dnsstat) error) error {
for _, h := range ds.hostseq {
if err := fn(h, ds.m[h]); err != nil {
return err
}
}
return nil
}
func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) (hosts []string, err error) {
ds.RLock()
defer ds.RUnlock()
// try to use DoH first
for _, addrs := range ds.m {
err = ds.rangeHosts(func(_ string, addrs []*dnsstat) error {
for _, addr := range addrs {
if !addr.e || !strings.HasPrefix(addr.a, "https://") { // disabled or is not DoH
continue
}
jr, err := lookupdoh(ctx, addr.a, host)
if err == nil {
hosts := jr.hosts()
hosts = jr.hosts()
if len(hosts) > 0 {
return hosts, nil
return nil
}
}
addr.disable(time.Hour) // no need to acquire write lock
}
return nil // not found, fallback to ds.b
})
if len(hosts) > 0 || err != nil {
return
}
if addrs, ok := ds.b[host]; ok {
return addrs, nil
@@ -133,7 +156,7 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFra
defer ds.RUnlock()
var conn net.Conn
for host, addrs := range ds.m {
_ = ds.rangeHosts(func(host string, addrs []*dnsstat) error {
for _, addr := range addrs {
if !addr.e || strings.HasPrefix(addr.a, "https://") { // disabled or is DoH
continue
@@ -155,16 +178,20 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFra
tlsConn = tls.Client(conn, &tls.Config{ServerName: host})
err = terasu.Use(tlsConn).HandshakeContext(ctx, firstFragmentLen)
if err == nil {
return
return nil
}
_ = tlsConn.Close()
addr.disable(time.Hour) // no need to acquire write lock
}
}
return nil
})
return
}
var IPv6Servers = DNSList{
hostseq: []string{
"dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net",
},
m: map[string][]*dnsstat{
"dot.sb": {
{"[2a09::]:853", true},
@@ -198,6 +225,9 @@ var IPv6Servers = DNSList{
}
var IPv4Servers = DNSList{
hostseq: []string{
"dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net",
},
m: map[string][]*dnsstat{
"dot.sb": {
{"185.222.222.222:853", true},

View File

@@ -66,9 +66,13 @@ func TestDNS(t *testing.T) {
}
func TestBadDNS(t *testing.T) {
dotv6serversseqbak := IPv6Servers.hostseq
dotv4serversseqbak := IPv4Servers.hostseq
dotv6serversbak := IPv6Servers.m
dotv4serversbak := IPv4Servers.m
defer func() {
IPv6Servers.hostseq = dotv6serversseqbak
IPv4Servers.hostseq = dotv4serversseqbak
IPv6Servers.m = dotv6serversbak
IPv4Servers.m = dotv4serversbak
}()
@@ -100,7 +104,7 @@ func TestBadDNS(t *testing.T) {
func (ds *DNSList) test() {
ds.RLock()
defer ds.RUnlock()
for host, addrs := range ds.m {
_ = ds.rangeHosts(func(host string, addrs []*dnsstat) error {
for _, addr := range addrs {
if !addr.e {
continue
@@ -119,5 +123,6 @@ func (ds *DNSList) test() {
}
fmt.Println("fail:", host, addr.a)
}
}
return nil
})
}