mirror of
https://github.com/fumiama/terasu.git
synced 2026-06-05 09:10:24 +08:00
feat(dns): support lookup in sequence
This commit is contained in:
50
dns/dns.go
50
dns/dns.go
@@ -41,8 +41,9 @@ func (ds *dnsstat) disable(reEnable time.Duration) {
|
||||
|
||||
type DNSList struct {
|
||||
sync.RWMutex
|
||||
m map[string][]*dnsstat
|
||||
b map[string][]string
|
||||
hostseq []string
|
||||
m map[string][]*dnsstat
|
||||
b map[string][]string
|
||||
}
|
||||
|
||||
type DNSConfig struct {
|
||||
@@ -74,9 +75,14 @@ func (ds *DNSList) Add(c *DNSConfig) {
|
||||
ds.Lock()
|
||||
defer ds.Unlock()
|
||||
addList := map[string][]*dnsstat{}
|
||||
addHosts := map[string]struct{}{}
|
||||
for host, addrs := range c.Servers {
|
||||
availableHosts, ok := ds.m[host]
|
||||
if !ok {
|
||||
addHosts[host] = struct{}{}
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
if !hasrecord(ds.m[host], addr) && !hasrecord(addList[host], addr) {
|
||||
if !hasrecord(availableHosts, addr) && !hasrecord(addList[host], addr) {
|
||||
addList[host] = append(addList[host], &dnsstat{addr, true})
|
||||
}
|
||||
}
|
||||
@@ -84,6 +90,9 @@ func (ds *DNSList) Add(c *DNSConfig) {
|
||||
for host, addrs := range addList {
|
||||
ds.m[host] = append(ds.m[host], addrs...)
|
||||
}
|
||||
for host := range addHosts {
|
||||
ds.hostseq = append(ds.hostseq, host)
|
||||
}
|
||||
addListFallback := map[string][]string{}
|
||||
for host, addrs := range c.Fallbacks {
|
||||
for _, addr := range addrs {
|
||||
@@ -97,24 +106,38 @@ func (ds *DNSList) Add(c *DNSConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) ([]string, error) {
|
||||
// rangeHosts in sequence, please use in rlock
|
||||
func (ds *DNSList) rangeHosts(fn func(host string, addrs []*dnsstat) error) error {
|
||||
for _, h := range ds.hostseq {
|
||||
if err := fn(h, ds.m[h]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ds *DNSList) lookupHostDoH(ctx context.Context, host string) (hosts []string, err error) {
|
||||
ds.RLock()
|
||||
defer ds.RUnlock()
|
||||
// try to use DoH first
|
||||
for _, addrs := range ds.m {
|
||||
err = ds.rangeHosts(func(_ string, addrs []*dnsstat) error {
|
||||
for _, addr := range addrs {
|
||||
if !addr.e || !strings.HasPrefix(addr.a, "https://") { // disabled or is not DoH
|
||||
continue
|
||||
}
|
||||
jr, err := lookupdoh(ctx, addr.a, host)
|
||||
if err == nil {
|
||||
hosts := jr.hosts()
|
||||
hosts = jr.hosts()
|
||||
if len(hosts) > 0 {
|
||||
return hosts, nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
addr.disable(time.Hour) // no need to acquire write lock
|
||||
}
|
||||
return nil // not found, fallback to ds.b
|
||||
})
|
||||
if len(hosts) > 0 || err != nil {
|
||||
return
|
||||
}
|
||||
if addrs, ok := ds.b[host]; ok {
|
||||
return addrs, nil
|
||||
@@ -133,7 +156,7 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFra
|
||||
defer ds.RUnlock()
|
||||
|
||||
var conn net.Conn
|
||||
for host, addrs := range ds.m {
|
||||
_ = ds.rangeHosts(func(host string, addrs []*dnsstat) error {
|
||||
for _, addr := range addrs {
|
||||
if !addr.e || strings.HasPrefix(addr.a, "https://") { // disabled or is DoH
|
||||
continue
|
||||
@@ -155,16 +178,20 @@ func (ds *DNSList) DialContext(ctx context.Context, dialer *net.Dialer, firstFra
|
||||
tlsConn = tls.Client(conn, &tls.Config{ServerName: host})
|
||||
err = terasu.Use(tlsConn).HandshakeContext(ctx, firstFragmentLen)
|
||||
if err == nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
_ = tlsConn.Close()
|
||||
addr.disable(time.Hour) // no need to acquire write lock
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var IPv6Servers = DNSList{
|
||||
hostseq: []string{
|
||||
"dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net",
|
||||
},
|
||||
m: map[string][]*dnsstat{
|
||||
"dot.sb": {
|
||||
{"[2a09::]:853", true},
|
||||
@@ -198,6 +225,9 @@ var IPv6Servers = DNSList{
|
||||
}
|
||||
|
||||
var IPv4Servers = DNSList{
|
||||
hostseq: []string{
|
||||
"dot.sb", "dns.google", "cloudflare-dns.com", "dns.opendns.com", "dns10.quad9.net",
|
||||
},
|
||||
m: map[string][]*dnsstat{
|
||||
"dot.sb": {
|
||||
{"185.222.222.222:853", true},
|
||||
|
||||
@@ -66,9 +66,13 @@ func TestDNS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBadDNS(t *testing.T) {
|
||||
dotv6serversseqbak := IPv6Servers.hostseq
|
||||
dotv4serversseqbak := IPv4Servers.hostseq
|
||||
dotv6serversbak := IPv6Servers.m
|
||||
dotv4serversbak := IPv4Servers.m
|
||||
defer func() {
|
||||
IPv6Servers.hostseq = dotv6serversseqbak
|
||||
IPv4Servers.hostseq = dotv4serversseqbak
|
||||
IPv6Servers.m = dotv6serversbak
|
||||
IPv4Servers.m = dotv4serversbak
|
||||
}()
|
||||
@@ -100,7 +104,7 @@ func TestBadDNS(t *testing.T) {
|
||||
func (ds *DNSList) test() {
|
||||
ds.RLock()
|
||||
defer ds.RUnlock()
|
||||
for host, addrs := range ds.m {
|
||||
_ = ds.rangeHosts(func(host string, addrs []*dnsstat) error {
|
||||
for _, addr := range addrs {
|
||||
if !addr.e {
|
||||
continue
|
||||
@@ -119,5 +123,6 @@ func (ds *DNSList) test() {
|
||||
}
|
||||
fmt.Println("fail:", host, addr.a)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user